Files
dataset-yolo-script/sam2-cpu/scripts/remove_duplicate_annotations.py
2026-02-04 15:29:36 +07:00

381 lines
12 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Remove duplicate annotations from YOLO label files.
Finds pairs of bounding boxes whose overlap (intersection area divided by the
area of the smaller box) is >= 10% by default, and removes the smaller box,
keeping the larger one.
Usage:
python scripts/remove_duplicate_annotations.py --input output/snapshots/background/labels
python scripts/remove_duplicate_annotations.py --input output/snapshots/background/labels --overlap-threshold 0.15
"""
import argparse
import sys
from pathlib import Path
from typing import List, Tuple, Optional
import logging
# Root logging setup: INFO and above, timestamped. main() lowers the root level
# to DEBUG when --debug is given.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
class BBox:
    """Axis-aligned box stored in YOLO centre/size form (normalized 0-1 values)."""

    def __init__(self, x_center: float, y_center: float, width: float, height: float):
        """
        Store a box given in YOLO format (normalized coordinates).
        Args:
            x_center: Normalized x centre (0-1)
            y_center: Normalized y centre (0-1)
            width: Normalized width (0-1)
            height: Normalized height (0-1)
        """
        self.x_center, self.y_center = x_center, y_center
        self.width, self.height = width, height

    def to_xyxy(self) -> Tuple[float, float, float, float]:
        """Return the corner coordinates (x1, y1, x2, y2), still in normalized units."""
        half_w = self.width / 2
        half_h = self.height / 2
        return (
            self.x_center - half_w,
            self.y_center - half_h,
            self.x_center + half_w,
            self.y_center + half_h,
        )

    def area(self) -> float:
        """Return width * height in normalized units."""
        return self.width * self.height

    def intersection_area(self, other: 'BBox') -> float:
        """Return the area shared with *other*; 0.0 when the boxes are disjoint or only touch."""
        ax1, ay1, ax2, ay2 = self.to_xyxy()
        bx1, by1, bx2, by2 = other.to_xyxy()
        left, top = max(ax1, bx1), max(ay1, by1)
        right, bottom = min(ax2, bx2), min(ay2, by2)
        if right <= left or bottom <= top:
            return 0.0
        return (right - left) * (bottom - top)

    def overlap_ratio(self, other: 'BBox') -> float:
        """
        Return intersection area divided by the smaller of the two box areas.

        This is the fraction of the smaller box covered by the other box.
        It is 0.0 when the boxes do not intersect or the smaller box has zero area.
        """
        shared = self.intersection_area(other)
        if shared == 0:
            return 0.0
        smaller = min(self.area(), other.area())
        return 0.0 if smaller == 0 else shared / smaller
class Annotation:
    """One box entry from a YOLO label file, plus the line it was read from."""

    def __init__(self, class_id: int, bbox: BBox, line_index: int):
        self.class_id = class_id
        self.bbox = bbox
        # Zero-based index of the source line in the label file.
        self.line_index = line_index

    def to_yolo_line(self) -> str:
        """Serialise as a YOLO line: class id followed by four coordinates with 6 decimals."""
        box = self.bbox
        values = (box.x_center, box.y_center, box.width, box.height)
        return f"{self.class_id} " + " ".join(f"{v:.6f}" for v in values)
def parse_yolo_label(label_path: Path) -> List[Annotation]:
    """
    Read a YOLO detection label file into Annotation objects.

    Each non-blank line must be ``class_id x_center y_center width height``,
    with all four box values inside [0, 1]. A line that does not match is
    logged as a warning and skipped; it does not abort parsing.

    Args:
        label_path: Path of the .txt label file to read.
    Returns:
        Annotations in file order. Each keeps the zero-based index of its
        source line. Returns an empty list if the file is missing. If reading
        fails, the error is logged and whatever was parsed so far is returned.
    """
    parsed: List[Annotation] = []
    if not label_path.exists():
        logger.warning(f"Label file not found: {label_path}")
        return parsed
    try:
        with open(label_path, 'r') as handle:
            raw_lines = handle.readlines()
        for line_no, raw in enumerate(raw_lines):
            text = raw.strip()
            if not text:
                continue
            fields = text.split()
            if len(fields) != 5:
                logger.warning(f"Invalid line format in {label_path}:{line_no+1}: {text}")
                continue
            try:
                class_id = int(fields[0])
                coords = [float(token) for token in fields[1:]]
            except ValueError as err:
                logger.warning(f"Error parsing line in {label_path}:{line_no+1}: {text} - {err}")
                continue
            # Reject anything outside the normalized [0, 1] range.
            if not all(0 <= value <= 1 for value in coords):
                logger.warning(f"Invalid coordinates in {label_path}:{line_no+1}: {text}")
                continue
            parsed.append(Annotation(class_id, BBox(*coords), line_no))
    except Exception as e:
        logger.error(f"Error reading {label_path}: {e}")
    return parsed
def find_duplicates(
    annotations: List[Annotation],
    overlap_threshold: float = 0.10,
    same_class_only: bool = False,
) -> List[int]:
    """
    Find indices of annotations to remove (duplicates).

    Boxes are visited from largest to smallest area. The sort is stable, so
    among equal areas the earlier line wins. A box is kept unless it overlaps
    an already-kept box by at least ``overlap_threshold``, where overlap is
    ``BBox.overlap_ratio`` (intersection / smaller area).

    Comparing only against kept boxes has two consequences:
    - a box is never discarded because of a box that is itself discarded;
    - the result does not depend on the order of the lines in the file.

    Args:
        annotations: Parsed annotations of one label file.
        overlap_threshold: Minimum overlap ratio to consider as duplicate
            (default: 0.10 = 10%). Boxes that do not intersect at all are
            never duplicates, even with a threshold of 0.
        same_class_only: If True, only boxes with the same ``class_id`` can
            suppress each other. Defaults to False (all classes compared).
    Returns:
        Indices into ``annotations`` to remove, sorted in descending order so
        they can be popped one by one without shifting the remaining indices.
    """
    # reverse=True keeps sort stability: equal-area boxes stay in file order.
    by_size = sorted(
        range(len(annotations)),
        key=lambda k: annotations[k].bbox.area(),
        reverse=True,
    )
    kept: List[int] = []
    to_remove: List[int] = []
    for idx in by_size:
        candidate = annotations[idx]
        for keep_idx in kept:
            keeper = annotations[keep_idx]
            if same_class_only and keeper.class_id != candidate.class_id:
                continue
            overlap = keeper.bbox.overlap_ratio(candidate.bbox)
            # "overlap > 0" stops a threshold of 0 from treating disjoint boxes as duplicates.
            if overlap > 0 and overlap >= overlap_threshold:
                to_remove.append(idx)
                logger.debug(f"Marking annotation {idx} for removal (smaller, overlap={overlap:.2%})")
                break
        else:
            kept.append(idx)
    return sorted(to_remove, reverse=True)
def remove_duplicates_from_file(label_path: Path, overlap_threshold: float = 0.10, dry_run: bool = False) -> Tuple[int, int]:
    """
    Remove duplicate annotations from a label file.

    Only the physical lines holding removed annotations are dropped. Every
    other line, including lines that parse_yolo_label skipped as invalid, is
    written back byte-for-byte, so no data or coordinate precision is lost.

    The new content goes to a sibling temporary file, which then replaces the
    original. A failed write therefore leaves the original file untouched.

    Args:
        label_path: Path to label file
        overlap_threshold: Minimum overlap ratio to consider as duplicate
        dry_run: If True, don't modify files, just report
    Returns:
        Tuple of (original_count, removed_count). original_count counts
        parsed annotations only. removed_count is 0 when nothing was removed
        or the rewrite failed.
    """
    annotations = parse_yolo_label(label_path)
    original_count = len(annotations)
    if original_count == 0:
        return (0, 0)
    to_remove = find_duplicates(annotations, overlap_threshold)
    removed_count = len(to_remove)
    if removed_count == 0:
        return (original_count, 0)
    if dry_run:
        logger.info(f"[DRY RUN] {label_path.name}: Would remove {removed_count}/{original_count} annotations")
        for idx in reversed(to_remove):
            logger.debug(f" Would remove line {annotations[idx].line_index + 1}: {annotations[idx].to_yolo_line()}")
        return (original_count, removed_count)

    # Map annotation indices back to physical line numbers in the file.
    drop_lines = {annotations[idx].line_index for idx in to_remove}
    tmp_path = label_path.with_name(label_path.name + '.tmp')
    try:
        # Re-read in the same mode parse_yolo_label used, so line indices line up.
        with open(label_path, 'r') as f:
            lines = f.readlines()
        with open(tmp_path, 'w') as f:
            f.writelines(line for i, line in enumerate(lines) if i not in drop_lines)
        # Swap the rewritten file into place only after it was fully written.
        tmp_path.replace(label_path)
    except (OSError, UnicodeError) as e:
        logger.error(f"Error writing {label_path}: {e}")
        try:
            tmp_path.unlink()
        except OSError:
            pass  # best effort: the temp file may never have been created
        return (original_count, 0)
    logger.info(f"{label_path.name}: Removed {removed_count}/{original_count} duplicate annotations")
    return (original_count, removed_count)
def process_directory(
    input_dir: Path,
    overlap_threshold: float = 0.10,
    dry_run: bool = False,
    recursive: bool = True
) -> None:
    """
    Process all label files in a directory, or a single label file.

    Args:
        input_dir: Directory containing label files (.txt), or one label file
        overlap_threshold: Minimum overlap ratio to consider as duplicate
        dry_run: If True, don't modify files, just report
        recursive: If True, search recursively in subdirectories
            (ignored when input_dir is a file)
    """
    if not input_dir.exists():
        logger.error(f"Input directory does not exist: {input_dir}")
        return
    if input_dir.is_file():
        # Allow --input to point at one label file instead of a directory.
        label_files = [input_dir]
    else:
        matches = input_dir.rglob("*.txt") if recursive else input_dir.glob("*.txt")
        # Sorted for a reproducible processing/log order. Directories whose
        # names end in .txt are skipped.
        label_files = sorted(p for p in matches if p.is_file())
    if not label_files:
        logger.warning(f"No .txt files found in {input_dir}")
        return
    logger.info(f"Found {len(label_files)} label files")
    logger.info(f"Overlap threshold: {overlap_threshold * 100:.1f}%")
    if dry_run:
        logger.info("DRY RUN MODE - No files will be modified")
    total_original = 0
    total_removed = 0
    files_modified = 0
    for label_file in label_files:
        original, removed = remove_duplicates_from_file(label_file, overlap_threshold, dry_run)
        total_original += original
        total_removed += removed
        if removed > 0:
            files_modified += 1
    logger.info("=" * 60)
    logger.info("Summary:")
    logger.info(f" Files processed: {len(label_files)}")
    logger.info(f" Files with duplicates: {files_modified}")
    logger.info(f" Total annotations: {total_original}")
    logger.info(f" Duplicates removed: {total_removed}")
    logger.info(f" Remaining annotations: {total_original - total_removed}")
    if total_original > 0:
        logger.info(f" Removal rate: {total_removed / total_original * 100:.2f}%")
def main():
    """
    Command-line entry point: parse arguments and run duplicate removal.

    Exits with status 1 when the overlap threshold is outside [0, 1] or the
    input path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Remove duplicate annotations from YOLO label files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Process all labels in directory (dry run)
python scripts/remove_duplicate_annotations.py \\
--input output/snapshots/background/labels \\
--dry-run
# Process and remove duplicates
python scripts/remove_duplicate_annotations.py \\
--input output/snapshots/background/labels
# Custom overlap threshold (15%)
python scripts/remove_duplicate_annotations.py \\
--input output/snapshots/background/labels \\
--overlap-threshold 0.15
"""
    )
    parser.add_argument(
        '--input', '-i',
        type=str,
        required=True,
        help='Input directory containing label files (.txt)'
    )
    parser.add_argument(
        '--overlap-threshold', '-t',
        type=float,
        default=0.10,
        help='Minimum overlap ratio to consider as duplicate (default: 0.10 = 10%%)'
    )
    parser.add_argument(
        '--dry-run', '-d',
        action='store_true',
        help='Dry run mode - report what would be removed without modifying files'
    )
    parser.add_argument(
        '--no-recursive',
        action='store_true',
        help="Don't search recursively in subdirectories"
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug logging'
    )
    args = parser.parse_args()
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
    if not (0 <= args.overlap_threshold <= 1):
        logger.error("Overlap threshold must be between 0 and 1")
        sys.exit(1)
    input_dir = Path(args.input)
    if not input_dir.exists():
        # process_directory would only log this and return, leaving exit
        # status 0; fail loudly so shell scripts/CI notice a bad path.
        logger.error(f"Input directory does not exist: {input_dir}")
        sys.exit(1)
    process_directory(
        input_dir=input_dir,
        overlap_threshold=args.overlap_threshold,
        dry_run=args.dry_run,
        recursive=not args.no_recursive
    )
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()