488 lines
16 KiB
Python
488 lines
16 KiB
Python
"""
|
|
Core YOLO-based video annotator.
|
|
|
|
Uses pretrained YOLOv9t to automatically detect and annotate objects in video frames.
|
|
"""
|
|
|
|
import cv2
|
|
import yaml
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from tqdm import tqdm
|
|
|
|
try:
|
|
from ultralytics import YOLO
|
|
HAS_ULTRALYTICS = True
|
|
except ImportError:
|
|
HAS_ULTRALYTICS = False
|
|
|
|
|
|
@dataclass
class BBox:
    """Axis-aligned bounding box stored as two corner points, in pixels."""

    x1: float  # x of the top-left corner
    y1: float  # y of the top-left corner
    x2: float  # x of the bottom-right corner
    y2: float  # y of the bottom-right corner

    @property
    def width(self) -> float:
        """Horizontal extent of the box (``x2 - x1``)."""
        return self.x2 - self.x1

    @property
    def height(self) -> float:
        """Vertical extent of the box (``y2 - y1``)."""
        return self.y2 - self.y1

    def area(self) -> float:
        """Return the box area in square pixels (width times height)."""
        return self.width * self.height

    def to_yolo(self, img_width: int, img_height: int) -> Tuple[float, float, float, float]:
        """Return ``(x_center, y_center, width, height)`` normalized by the image size.

        Args:
            img_width: Width of the image the box belongs to, in pixels.
            img_height: Height of the image the box belongs to, in pixels.
        """
        mid_x = (self.x1 + self.x2) / 2
        mid_y = (self.y1 + self.y2) / 2
        return (
            mid_x / img_width,
            mid_y / img_height,
            self.width / img_width,
            self.height / img_height,
        )
|
|
|
|
|
|
@dataclass
class Detection:
    """One detected object in one video frame."""

    class_id: int                   # numeric class index reported by the model
    class_name: str                 # display name for class_id
    confidence: float               # detection score reported by the model
    bbox: BBox                      # box location in pixel coordinates
    frame_id: int                   # index of the frame the detection came from
    timestamp: float = 0.0          # defaults to 0.0 when the caller does not set it
    track_id: Optional[int] = None  # None when no track id has been assigned
|
|
|
|
|
|
@dataclass
class AnnotationResult:
    """Summary of one video annotation run plus the detections it produced."""

    video_path: str        # source that was annotated
    total_frames: int      # frame count reported for the video
    processed_frames: int  # number of frames detection was actually run on
    total_detections: int  # number of detections kept across all processed frames
    # Maps frame index -> detections kept for that frame; each instance gets its own dict.
    detections_per_frame: Dict[int, List[Detection]] = field(default_factory=dict)
    output_dir: Optional[str] = None  # directory outputs were written to, if any
    # Set to datetime.now() at the moment the result object is constructed.
    created_at: datetime = field(default_factory=datetime.now)
|
|
|
|
|
|
class YOLOAnnotator:
    """YOLO-based automatic video annotator.

    Samples frames from a video, runs the loaded model on each sampled frame,
    filters the detections according to the ``detection`` config section and
    optionally writes snapshots, YOLO label files, debug overlays and a JSON
    manifest under the output directory.
    """

    def __init__(self, config_path: Optional[str] = None, config: Optional[Dict] = None):
        """
        Initialize annotator.

        Args:
            config_path: Path to YAML config file
            config: Config dictionary (overrides config_path)

        Raises:
            ImportError: If the ``ultralytics`` package could not be imported.
        """
        if not HAS_ULTRALYTICS:
            raise ImportError("ultralytics package required. Install with: pip install ultralytics")

        self.config = self._load_config(config_path, config)
        self.model = None  # loaded lazily by load_model()
        # ``class_names: null`` in a config would otherwise leave None here
        # and break the ``.get`` lookup in process_frame.
        self.class_names = self.config.get('class_names') or {}

    def _load_config(self, config_path: Optional[str], config: Optional[Dict]) -> Dict:
        """Load configuration from file or dict.

        Precedence: an explicit ``config`` dict, then the YAML file at
        ``config_path``, then the built-in defaults.

        Raises:
            ValueError: If the YAML document is not a mapping.
        """
        if config is not None:
            return config

        if config_path is not None:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
            # safe_load yields None for an empty document; treat that as
            # "no overrides" instead of crashing on the first ``.get``.
            if loaded is None:
                return {}
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"Config file must contain a mapping at top level: {config_path}"
                )
            return loaded

        # Default config
        return {
            'model': {
                'path': 'yolov9t.pt',
                'device': 'cuda',
                'conf_threshold': 0.25,
                'iou_threshold': 0.45,
            },
            'video': {
                'sample_fps': 2,
                'max_frames': None,
                'start_time': 0,
                'end_time': None,
                'resize': None,
            },
            'detection': {
                'classes': None,
                'min_confidence': 0.3,
                'min_area': 100,
                'max_area': None,
                'min_size': 0.01,
            },
            'output': {
                'directory': 'output/annotations',
                'save_snapshots': True,
                'save_labels': True,
                'save_debug': True,
                'save_manifest': True,
                'image_format': 'jpg',
                'image_quality': 95,
            },
        }

    def load_model(self, model_path: Optional[str] = None, device: Optional[str] = None) -> None:
        """
        Load YOLOv9t model.

        Args:
            model_path: Path to model weights (overrides config)
            device: Device to use (overrides config)
        """
        model_cfg = self.config.get('model', {})
        path = model_path or model_cfg.get('path', 'yolov9t.pt')
        dev = device or model_cfg.get('device', 'cuda')

        print(f"Loading model: {path}")
        self.model = YOLO(path)
        self.model.to(dev)
        print(f"Model loaded on {dev}")

    @staticmethod
    def _sampling_interval(fps: float, sample_fps: Optional[float]) -> int:
        """Return how many source frames to advance between sampled frames (>= 1)."""
        if sample_fps and 0 < sample_fps < fps:
            return max(1, int(fps / sample_fps))
        return 1

    @staticmethod
    def _expected_samples(
        start_frame: int,
        end_frame: int,
        frame_interval: int,
        max_frames: Optional[int],
    ) -> int:
        """Return how many frames the sampling loop will process at most."""
        span = max(0, end_frame - start_frame)
        # Frames start, start+interval, ... below end: ceil(span / interval).
        count = -(-span // frame_interval)
        if max_frames:
            count = min(count, max_frames)
        return int(count)

    def _save_frame_outputs(
        self,
        frame: np.ndarray,
        detections: List[Detection],
        frame_idx: int,
        out_path: Path,
        output_cfg: Dict,
    ) -> None:
        """Write the snapshot, YOLO label file and debug overlay for one frame."""
        if not detections:
            # Nothing is written for frames without any kept detection.
            return

        img_format = output_cfg.get('image_format', 'jpg')
        img_quality = output_cfg.get('image_quality', 95)
        frame_name = f"frame_{frame_idx:06d}"
        h, w = frame.shape[:2]

        if output_cfg.get('save_snapshots', True):
            img_path = out_path / 'images' / f"{frame_name}.{img_format}"
            if img_format == 'jpg':
                cv2.imwrite(str(img_path), frame, [cv2.IMWRITE_JPEG_QUALITY, img_quality])
            else:
                cv2.imwrite(str(img_path), frame)

        if output_cfg.get('save_labels', True):
            label_path = out_path / 'labels' / f"{frame_name}.txt"
            self._write_yolo_label(label_path, detections, w, h)

        if output_cfg.get('save_debug', True):
            # Draw on a copy so the frame passed in stays unmodified.
            debug_frame = self._draw_detections(frame.copy(), detections)
            debug_path = out_path / 'debug' / f"{frame_name}.{img_format}"
            cv2.imwrite(str(debug_path), debug_frame)

    def process_video(
        self,
        video_path: Optional[str] = None,
        output_dir: Optional[str] = None
    ) -> AnnotationResult:
        """
        Process entire video and generate annotations.

        Args:
            video_path: Path to video file (overrides config)
            output_dir: Output directory (overrides config)

        Returns:
            AnnotationResult with all detections

        Raises:
            ValueError: If no video source is given or the video cannot be opened.
        """
        if self.model is None:
            self.load_model()

        video_cfg = self.config.get('video', {})
        output_cfg = self.config.get('output', {})
        model_cfg = self.config.get('model', {})

        source = video_path or video_cfg.get('source')
        if source is None:
            raise ValueError("No video source specified")

        out_path = Path(output_dir or output_cfg.get('directory', 'output/annotations'))

        # Create output directories only for the outputs that are enabled.
        for flag, subdir in (
            ('save_snapshots', 'images'),
            ('save_labels', 'labels'),
            ('save_debug', 'debug'),
        ):
            if output_cfg.get(flag, True):
                (out_path / subdir).mkdir(parents=True, exist_ok=True)

        conf_thresh = model_cfg.get('conf_threshold', 0.25)
        iou_thresh = model_cfg.get('iou_threshold', 0.45)
        resize = video_cfg.get('resize')
        max_frames = video_cfg.get('max_frames')

        cap = cv2.VideoCapture(source)
        pbar = None
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video: {source}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            frame_interval = self._sampling_interval(fps, video_cfg.get('sample_fps'))

            # Frame range ("start_time: null" in a config is treated as 0).
            start_time = video_cfg.get('start_time') or 0
            end_time = video_cfg.get('end_time')
            start_frame = int(start_time * fps)
            end_frame = int(end_time * fps) if end_time else total_frames
            end_frame = min(end_frame, total_frames)

            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            result = AnnotationResult(
                video_path=source,
                total_frames=total_frames,
                processed_frames=0,
                total_detections=0,
                output_dir=str(out_path),
            )

            pbar = tqdm(
                total=self._expected_samples(start_frame, end_frame, frame_interval, max_frames),
                desc="Annotating",
            )

            processed = 0
            for frame_idx in range(start_frame, end_frame):
                if max_frames and processed >= max_frames:
                    break

                ret, frame = cap.read()
                if not ret:
                    break

                # Every frame is read to keep the capture position in step
                # with frame_idx; only every frame_interval-th is annotated.
                if (frame_idx - start_frame) % frame_interval != 0:
                    continue

                if resize:
                    frame = cv2.resize(frame, tuple(resize))

                # Dimensions after the optional resize.
                h, w = frame.shape[:2]

                detections = self.process_frame(
                    frame,
                    frame_id=frame_idx,
                    conf_threshold=conf_thresh,
                    iou_threshold=iou_thresh,
                )
                detections = self.filter_detections(detections, w, h)

                result.detections_per_frame[frame_idx] = detections
                result.total_detections += len(detections)

                self._save_frame_outputs(frame, detections, frame_idx, out_path, output_cfg)

                processed += 1
                pbar.update(1)
        finally:
            # Release the capture and progress bar even when a frame fails.
            if pbar is not None:
                pbar.close()
            cap.release()

        result.processed_frames = processed

        # Save manifest
        if output_cfg.get('save_manifest', True):
            # out_path itself is not created above when every save_* flag is off.
            out_path.mkdir(parents=True, exist_ok=True)
            self._save_manifest(result, out_path / 'manifest.json')

        print("\nAnnotation complete!")
        print(f" Processed frames: {processed}")
        print(f" Total detections: {result.total_detections}")
        print(f" Output: {out_path}")

        return result

    def process_frame(
        self,
        frame: np.ndarray,
        frame_id: int = 0,
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
    ) -> List[Detection]:
        """
        Process single frame and return detections.

        Args:
            frame: Input frame (BGR)
            frame_id: Frame index
            conf_threshold: Confidence threshold
            iou_threshold: NMS IoU threshold

        Returns:
            List of Detection objects
        """
        if self.model is None:
            self.load_model()

        # Run inference
        results = self.model.predict(
            frame,
            conf=conf_threshold,
            iou=iou_threshold,
            verbose=False,
        )

        detections = []
        for result in results:
            if result.boxes is None:
                continue

            for box in result.boxes:
                class_id = int(box.cls[0].item())
                confidence = float(box.conf[0].item())
                x1, y1, x2, y2 = box.xyxy[0].tolist()

                # Fall back to the numeric id when no name is configured.
                class_name = self.class_names.get(class_id, str(class_id))

                detections.append(
                    Detection(
                        class_id=class_id,
                        class_name=class_name,
                        confidence=confidence,
                        bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2),
                        frame_id=frame_id,
                    )
                )

        return detections

    def filter_detections(
        self,
        detections: List[Detection],
        img_width: int,
        img_height: int,
    ) -> List[Detection]:
        """
        Apply filtering rules to detections.

        A detection is kept only if its class is allowed, its confidence is
        at least ``min_confidence``, its pixel area lies within
        ``[min_area, max_area]`` and both normalized sides are at least
        ``min_size``.

        Args:
            detections: List of detections to filter
            img_width: Image width
            img_height: Image height

        Returns:
            Filtered list of detections
        """
        det_cfg = self.config.get('detection', {})

        allowed_classes = det_cfg.get('classes')  # None = all
        min_conf = det_cfg.get('min_confidence', 0.3)
        min_area = det_cfg.get('min_area', 100)
        max_area = det_cfg.get('max_area')
        min_size = det_cfg.get('min_size', 0.01)

        # Build the membership set once instead of scanning a list per detection.
        allowed = set(allowed_classes) if allowed_classes is not None else None

        filtered = []
        for det in detections:
            # Filter by class
            if allowed is not None and det.class_id not in allowed:
                continue

            # Filter by confidence
            if det.confidence < min_conf:
                continue

            # Filter by area
            area = det.bbox.area()
            if area < min_area:
                continue
            if max_area is not None and area > max_area:
                continue

            # Filter by normalized size
            norm_w = det.bbox.width / img_width
            norm_h = det.bbox.height / img_height
            if norm_w < min_size or norm_h < min_size:
                continue

            filtered.append(det)

        return filtered

    def _write_yolo_label(
        self,
        path: Path,
        detections: List[Detection],
        img_width: int,
        img_height: int,
    ) -> None:
        """Write YOLO format label file.

        One line per detection: ``class_id x_center y_center width height``
        with coordinates normalized by the image size and six decimals.
        """
        lines = []
        for det in detections:
            x_c, y_c, w, h = det.bbox.to_yolo(img_width, img_height)
            lines.append(f"{det.class_id} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")

        with open(path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(lines))

    def _draw_detections(
        self,
        frame: np.ndarray,
        detections: List[Detection],
    ) -> np.ndarray:
        """Draw detection boxes on frame.

        Draws directly on ``frame`` and returns that same array.
        """
        color = (0, 255, 0)
        font = cv2.FONT_HERSHEY_SIMPLEX

        for det in detections:
            x1, y1 = int(det.bbox.x1), int(det.bbox.y1)
            x2, y2 = int(det.bbox.x2), int(det.bbox.y2)

            # Draw box
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            # Draw label on a filled background just above the box
            label = f"{det.class_name} {det.confidence:.2f}"
            (label_w, label_h), _ = cv2.getTextSize(label, font, 0.5, 1)
            cv2.rectangle(frame, (x1, y1 - label_h - 4), (x1 + label_w, y1), color, -1)
            cv2.putText(frame, label, (x1, y1 - 2), font, 0.5, (0, 0, 0), 1)

        return frame

    def _save_manifest(self, result: AnnotationResult, path: Path) -> None:
        """Save annotation manifest as JSON.

        The manifest holds the run summary plus, for each processed frame
        (keyed by the frame index as a string), the class, confidence
        (rounded to 4 decimals) and pixel bbox of each kept detection.
        """
        import json

        manifest = {
            'created': result.created_at.isoformat(),
            'video_path': result.video_path,
            'total_frames': result.total_frames,
            'processed_frames': result.processed_frames,
            'total_detections': result.total_detections,
            'frames': {
                str(frame_id): [
                    {
                        'class_id': d.class_id,
                        'class_name': d.class_name,
                        'confidence': round(d.confidence, 4),
                        'bbox': [d.bbox.x1, d.bbox.y1, d.bbox.x2, d.bbox.y2],
                    }
                    for d in detections
                ]
                for frame_id, detections in result.detections_per_frame.items()
            },
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2)
|