381 lines
12 KiB
Python
Executable File
381 lines
12 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Remove duplicate annotations from YOLO label files.
|
|
|
|
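Searches for pairs of bounding boxes whose overlap (intersection area divided
by the area of the smaller box) is at least the threshold (10% by default) and
removes the smaller box, keeping the larger one. Boxes are compared regardless
of class.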
|
|
|
|
Usage:
|
|
python scripts/remove_duplicate_annotations.py --input output/snapshots/background/labels
|
|
python scripts/remove_duplicate_annotations.py --input output/snapshots/background/labels --overlap-threshold 0.15
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import List, Tuple, Optional
|
|
import logging
|
|
|
|
# Root-logger configuration for the command-line tool: timestamp, level, message.
# logging.basicConfig has no effect if the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by all functions in this script.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BBox:
    """
    Axis-aligned bounding box stored in YOLO's normalized center format.

    All values are fractions of the image size, so areas and intersections
    returned by the methods below are in normalized units as well.
    """

    def __init__(self, x_center: float, y_center: float, width: float, height: float):
        """
        Store a box given in YOLO format (normalized coordinates).

        Args:
            x_center: Normalized x center (0-1)
            y_center: Normalized y center (0-1)
            width: Normalized width (0-1)
            height: Normalized height (0-1)
        """
        self.x_center = x_center
        self.y_center = y_center
        self.width = width
        self.height = height

    def to_xyxy(self) -> Tuple[float, float, float, float]:
        """Return the corner form (left, top, right, bottom), still in normalized space."""
        half_w = self.width / 2
        half_h = self.height / 2
        return (
            self.x_center - half_w,
            self.y_center - half_h,
            self.x_center + half_w,
            self.y_center + half_h,
        )

    def area(self) -> float:
        """Return width * height in normalized space."""
        return self.width * self.height

    def intersection_area(self, other: 'BBox') -> float:
        """Return the area shared with *other*; 0.0 when the boxes do not overlap."""
        a_left, a_top, a_right, a_bottom = self.to_xyxy()
        b_left, b_top, b_right, b_bottom = other.to_xyxy()

        left = max(a_left, b_left)
        top = max(a_top, b_top)
        right = min(a_right, b_right)
        bottom = min(a_bottom, b_bottom)

        # Boxes that only touch along an edge (zero width or height) do not overlap.
        if right <= left or bottom <= top:
            return 0.0
        return (right - left) * (bottom - top)

    def overlap_ratio(self, other: 'BBox') -> float:
        """
        Return intersection / min(area of self, area of other).

        This is the fraction of the smaller box covered by the other one, so a
        box lying entirely inside another scores 1.0. Returns 0.0 when the
        boxes do not intersect or the smaller box has zero area.
        """
        shared = self.intersection_area(other)
        if shared == 0:
            return 0.0

        smaller = min(self.area(), other.area())
        return 0.0 if smaller == 0 else shared / smaller
|
|
|
|
|
|
class Annotation:
    """One parsed label line: class id, bounding box, and its position in the file."""

    def __init__(self, class_id: int, bbox: BBox, line_index: int):
        self.class_id = class_id
        self.bbox = bbox
        # 0-based index of the line this annotation was read from
        self.line_index = line_index

    def to_yolo_line(self) -> str:
        """Serialize as "class_id x_center y_center width height" with 6 decimals per coordinate."""
        box = self.bbox
        coords = " ".join(
            format(value, ".6f")
            for value in (box.x_center, box.y_center, box.width, box.height)
        )
        return f"{self.class_id} {coords}"
|
|
|
|
|
|
def parse_yolo_label(label_path: Path) -> List[Annotation]:
    """
    Read a YOLO detection label file into Annotation objects.

    Each non-blank line must be "class_id x_center y_center width height",
    with the four box values normalized to [0, 1]. Lines with the wrong
    number of fields, unparsable numbers, or out-of-range values are logged
    as warnings and skipped. A missing file is logged as a warning and a read
    failure as an error; both produce an empty list.

    Args:
        label_path: Path to the .txt label file

    Returns:
        Annotations in file order, each carrying its 0-based line index
    """
    found = []

    if not label_path.exists():
        logger.warning(f"Label file not found: {label_path}")
        return found

    try:
        # All lines are read before parsing, so a read/decode error leaves the
        # result empty rather than partially filled.
        with open(label_path, 'r') as handle:
            raw_lines = handle.readlines()

        for line_no, raw in enumerate(raw_lines):
            text = raw.strip()
            if not text:
                continue

            fields = text.split()
            if len(fields) != 5:
                logger.warning(f"Invalid line format in {label_path}:{line_no+1}: {text}")
                continue

            try:
                cls = int(fields[0])
                xc, yc, w, h = map(float, fields[1:])

                if not all(0 <= v <= 1 for v in (xc, yc, w, h)):
                    logger.warning(f"Invalid coordinates in {label_path}:{line_no+1}: {text}")
                    continue

                found.append(Annotation(cls, BBox(xc, yc, w, h), line_no))
            except ValueError as e:
                logger.warning(f"Error parsing line in {label_path}:{line_no+1}: {text} - {e}")
    except Exception as e:
        logger.error(f"Error reading {label_path}: {e}")

    return found
|
|
|
|
|
|
def find_duplicates(annotations: List['Annotation'], overlap_threshold: float = 0.10) -> List[int]:
    """
    Find indices of annotations to remove (duplicates).

    Boxes are visited from largest to smallest area; equal areas keep list
    order, so the earlier index wins. A box is kept unless it overlaps an
    already-kept box by at least ``overlap_threshold``. Overlap is measured by
    ``BBox.overlap_ratio`` (intersection / smaller area).

    Each box is compared only against boxes that survive. As a result:
    - the set of removed boxes does not depend on the order of lines in the
      file;
    - a box is never dropped because of a box that is itself dropped.

    Boxes that do not intersect at all are never duplicates, even when
    ``overlap_threshold`` is 0. Class ids are ignored, so boxes of different
    classes can be duplicates of each other.

    Args:
        annotations: List of annotations
        overlap_threshold: Minimum overlap ratio to consider as duplicate (default: 0.10 = 10%)

    Returns:
        List of indices to remove (sorted in descending order for safe removal)
    """
    # Largest first. sorted() stays stable with reverse=True, so equal areas keep list order.
    by_size = sorted(
        range(len(annotations)),
        key=lambda k: annotations[k].bbox.area(),
        reverse=True,
    )

    kept: List[int] = []
    to_remove: List[int] = []
    for idx in by_size:
        bbox = annotations[idx].bbox
        for kept_idx in kept:
            overlap = annotations[kept_idx].bbox.overlap_ratio(bbox)
            # Requiring overlap > 0 stops a threshold of 0 from matching disjoint boxes.
            if overlap > 0 and overlap >= overlap_threshold:
                to_remove.append(idx)
                logger.debug(
                    f"Marking annotation {idx} for removal "
                    f"(overlaps kept annotation {kept_idx}, overlap={overlap:.2%})"
                )
                break
        else:
            kept.append(idx)

    return sorted(to_remove, reverse=True)
|
|
|
|
|
|
def remove_duplicates_from_file(label_path: Path, overlap_threshold: float = 0.10, dry_run: bool = False) -> Tuple[int, int]:
    """
    Remove duplicate annotations from a label file.

    Only the lines holding duplicate annotations are deleted. Every other line
    is written back exactly as it was read. That includes blank lines and lines
    parse_yolo_label skipped as invalid (e.g. rows with extra fields or
    coordinates outside [0, 1]), so unrelated content and the original
    coordinate precision are preserved.

    Args:
        label_path: Path to label file
        overlap_threshold: Minimum overlap ratio to consider as duplicate
        dry_run: If True, don't modify files, just report

    Returns:
        Tuple of (original_count, removed_count). original_count counts only
        annotations that parsed successfully. removed_count is 0 if the file
        could not be rewritten.
    """
    annotations = parse_yolo_label(label_path)
    original_count = len(annotations)

    if original_count == 0:
        return (0, 0)

    to_remove = find_duplicates(annotations, overlap_threshold)
    removed_count = len(to_remove)

    if removed_count == 0:
        return (original_count, 0)

    if dry_run:
        logger.info(f"[DRY RUN] {label_path.name}: Would remove {removed_count}/{original_count} annotations")
        for idx in reversed(to_remove):
            logger.debug(f"  Would remove line {annotations[idx].line_index + 1}: {annotations[idx].to_yolo_line()}")
        return (original_count, removed_count)

    # Translate positions in `annotations` into 0-based line numbers of the file.
    drop_lines = {annotations[idx].line_index for idx in to_remove}

    try:
        # Re-read the file the same way parse_yolo_label does (text-mode
        # readlines), so each line_index points at the line it came from.
        with open(label_path, 'r') as f:
            lines = f.readlines()

        # Build the new content first so the file is truncated only right before one write.
        kept_text = "".join(line for i, line in enumerate(lines) if i not in drop_lines)
        with open(label_path, 'w') as f:
            f.write(kept_text)
    except Exception as e:
        # Per-file best effort: report the failure and let the batch continue.
        logger.error(f"Error writing {label_path}: {e}")
        return (original_count, 0)

    logger.info(f"{label_path.name}: Removed {removed_count}/{original_count} duplicate annotations")
    return (original_count, removed_count)
|
|
|
|
|
|
def process_directory(
    input_dir: Path,
    overlap_threshold: float = 0.10,
    dry_run: bool = False,
    recursive: bool = True
) -> None:
    """
    Process all label files in a directory.

    Every regular file whose name matches "*.txt" is treated as a label file.
    Files are processed in sorted path order, so repeated runs produce the
    same log output.

    Args:
        input_dir: Directory containing label files
        overlap_threshold: Minimum overlap ratio to consider as duplicate
        dry_run: If True, don't modify files, just report
        recursive: If True, search recursively in subdirectories
    """
    if not input_dir.exists():
        logger.error(f"Input directory does not exist: {input_dir}")
        return
    if not input_dir.is_dir():
        logger.error(f"Input path is not a directory: {input_dir}")
        return

    # Find all .txt files. The glob can also match directories named "*.txt",
    # so only regular files are kept.
    matches = input_dir.rglob("*.txt") if recursive else input_dir.glob("*.txt")
    label_files = sorted(path for path in matches if path.is_file())

    if not label_files:
        logger.warning(f"No .txt files found in {input_dir}")
        return

    logger.info(f"Found {len(label_files)} label files")
    logger.info(f"Overlap threshold: {overlap_threshold * 100:.1f}%")
    if dry_run:
        logger.info("DRY RUN MODE - No files will be modified")

    total_original = 0
    total_removed = 0
    files_modified = 0

    for label_file in label_files:
        original, removed = remove_duplicates_from_file(label_file, overlap_threshold, dry_run)
        total_original += original
        total_removed += removed
        if removed > 0:
            files_modified += 1

    logger.info("=" * 60)
    logger.info("Summary:")
    logger.info(f"  Files processed: {len(label_files)}")
    logger.info(f"  Files with duplicates: {files_modified}")
    logger.info(f"  Total annotations: {total_original}")
    logger.info(f"  Duplicates removed: {total_removed}")
    logger.info(f"  Remaining annotations: {total_original - total_removed}")
    if total_original > 0:
        logger.info(f"  Removal rate: {total_removed / total_original * 100:.2f}%")
|
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Parses the arguments, validates the overlap threshold and the input path,
    then runs process_directory. Exits with status 1 in two cases:
    - the threshold is outside [0, 1];
    - the input path is not an existing directory.
    """
    parser = argparse.ArgumentParser(
        description="Remove duplicate annotations from YOLO label files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process all labels in directory (dry run)
  python scripts/remove_duplicate_annotations.py \\
    --input output/snapshots/background/labels \\
    --dry-run

  # Process and remove duplicates
  python scripts/remove_duplicate_annotations.py \\
    --input output/snapshots/background/labels

  # Custom overlap threshold (15%)
  python scripts/remove_duplicate_annotations.py \\
    --input output/snapshots/background/labels \\
    --overlap-threshold 0.15
"""
    )

    parser.add_argument(
        '--input', '-i',
        type=str,
        required=True,
        help='Input directory containing label files (.txt)'
    )

    parser.add_argument(
        '--overlap-threshold', '-t',
        type=float,
        default=0.10,
        # argparse %-formats help strings, hence the doubled percent sign.
        help='Minimum overlap ratio to consider as duplicate (default: 0.10 = 10%%)'
    )

    parser.add_argument(
        '--dry-run', '-d',
        action='store_true',
        help='Dry run mode - report what would be removed without modifying files'
    )

    parser.add_argument(
        '--no-recursive',
        action='store_true',
        help="Don't search recursively in subdirectories"
    )

    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug logging'
    )

    args = parser.parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    input_dir = Path(args.input)

    # NaN also fails this chained comparison, so it is rejected too.
    if not (0 <= args.overlap_threshold <= 1):
        logger.error("Overlap threshold must be between 0 and 1")
        sys.exit(1)

    # Exit non-zero on a bad path so shell scripts and CI notice the failure.
    if not input_dir.is_dir():
        logger.error(f"Input directory does not exist or is not a directory: {input_dir}")
        sys.exit(1)

    process_directory(
        input_dir=input_dir,
        overlap_threshold=args.overlap_threshold,
        dry_run=args.dry_run,
        recursive=not args.no_recursive
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when the file is executed directly, not when imported.
    main()
|