add sam2 yolo auto annotation
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
YOLO-Assisted Video Annotator
|
||||
|
||||
Auto-annotate videos using pretrained YOLOv9t model.
|
||||
Outputs clean snapshots paired with YOLO format labels.
|
||||
"""
|
||||
|
||||
from .annotator import YOLOAnnotator
|
||||
from .video_source import VideoSource
|
||||
from .export import AnnotationExporter
|
||||
from .visualizer import DebugVisualizer
|
||||
|
||||
# Package version string.
__version__ = "0.1.0"

# Names re-exported as the public API of the package.
__all__ = ["YOLOAnnotator", "VideoSource", "AnnotationExporter", "DebugVisualizer"]
|
||||
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
Core YOLO-based video annotator.
|
||||
|
||||
Uses pretrained YOLOv9t to automatically detect and annotate objects in video frames.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import yaml
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
HAS_ULTRALYTICS = True
|
||||
except ImportError:
|
||||
HAS_ULTRALYTICS = False
|
||||
|
||||
|
||||
@dataclass
class BBox:
    """Axis-aligned bounding box stored as pixel corner coordinates."""

    x1: float  # left edge (pixels)
    y1: float  # top edge (pixels)
    x2: float  # right edge (pixels)
    y2: float  # bottom edge (pixels)

    @property
    def width(self) -> float:
        """Horizontal extent of the box, ``x2 - x1``."""
        return self.x2 - self.x1

    @property
    def height(self) -> float:
        """Vertical extent of the box, ``y2 - y1``."""
        return self.y2 - self.y1

    def area(self) -> float:
        """Return the box area in square pixels (width times height)."""
        return self.width * self.height

    def to_yolo(self, img_width: int, img_height: int) -> Tuple[float, float, float, float]:
        """Return ``(x_center, y_center, width, height)`` normalized by image size.

        Args:
            img_width: Image width used to normalize the x values.
            img_height: Image height used to normalize the y values.
        """
        mid_x = (self.x1 + self.x2) / 2
        mid_y = (self.y1 + self.y2) / 2
        return (
            mid_x / img_width,
            mid_y / img_height,
            self.width / img_width,
            self.height / img_height,
        )
|
||||
|
||||
|
||||
@dataclass
class Detection:
    """One detected object within a single frame."""

    class_id: int  # numeric class index
    class_name: str  # display name for the class
    confidence: float  # detection score
    bbox: BBox  # box in pixel coordinates
    frame_id: int  # index of the frame the detection belongs to
    timestamp: float = 0.0  # defaults to 0.0 when the producer does not set it
    track_id: Optional[int] = None  # None unless a tracker assigns an id
|
||||
|
||||
|
||||
@dataclass
class AnnotationResult:
    """Summary of one video annotation run."""

    video_path: str  # source that was annotated
    total_frames: int  # frame count reported for the video
    processed_frames: int  # number of frames actually run through detection
    total_detections: int  # detections kept across all processed frames
    # Kept detections keyed by frame index; each instance gets its own dict.
    detections_per_frame: Dict[int, List[Detection]] = field(default_factory=dict)
    output_dir: Optional[str] = None  # where outputs were written, if anywhere
    # Creation time, captured when the instance is constructed.
    created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class YOLOAnnotator:
    """YOLO-based automatic video annotator.

    Runs a YOLO model over sampled video frames, filters the detections and,
    depending on the ``output`` config section, writes snapshots, YOLO-format
    label files, debug overlays and a JSON manifest.
    """

    def __init__(self, config_path: Optional[str] = None, config: Optional[Dict] = None):
        """
        Initialize annotator.

        Args:
            config_path: Path to YAML config file
            config: Config dictionary (takes precedence over config_path)

        Raises:
            ImportError: If the ultralytics package could not be imported.
        """
        if not HAS_ULTRALYTICS:
            raise ImportError("ultralytics package required. Install with: pip install ultralytics")

        self.config = self._load_config(config_path, config)
        self.model = None  # loaded lazily by load_model()
        # ``or {}`` guards against an explicit ``class_names: null`` entry.
        self.class_names = self.config.get('class_names') or {}

    def _load_config(self, config_path: Optional[str], config: Optional[Dict]) -> Dict:
        """Load configuration from a dict, a YAML file, or built-in defaults.

        Args:
            config_path: Path to a YAML file; used only when ``config`` is None.
            config: Ready-made config dictionary, returned as-is when given.

        Returns:
            The configuration dictionary.

        Raises:
            ValueError: If the YAML file does not contain a mapping.
        """
        if config is not None:
            return config

        if config_path is not None:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
            # safe_load returns None for an empty file; treat that as "no
            # overrides" so later ``self.config.get`` calls do not crash.
            if loaded is None:
                return {}
            if not isinstance(loaded, dict):
                raise ValueError(f"Config file must contain a mapping: {config_path}")
            return loaded

        # Default config
        return {
            'model': {
                'path': 'yolov9t.pt',
                'device': 'cuda',
                'conf_threshold': 0.25,
                'iou_threshold': 0.45,
            },
            'video': {
                'sample_fps': 2,
                'max_frames': None,
                'start_time': 0,
                'end_time': None,
                'resize': None,
            },
            'detection': {
                'classes': None,
                'min_confidence': 0.3,
                'min_area': 100,
                'max_area': None,
                'min_size': 0.01,
            },
            'output': {
                'directory': 'output/annotations',
                'save_snapshots': True,
                'save_labels': True,
                'save_debug': True,
                'save_manifest': True,
                'image_format': 'jpg',
                'image_quality': 95,
            },
        }

    def load_model(self, model_path: Optional[str] = None, device: Optional[str] = None) -> None:
        """
        Load the YOLO model and move it to the target device.

        Args:
            model_path: Path to model weights (overrides config)
            device: Device to use (overrides config)
        """
        model_cfg = self.config.get('model', {})
        path = model_path or model_cfg.get('path', 'yolov9t.pt')
        dev = device or model_cfg.get('device', 'cuda')

        print(f"Loading model: {path}")
        self.model = YOLO(path)
        self.model.to(dev)
        print(f"Model loaded on {dev}")

    @staticmethod
    def _count_sampled_frames(
        start_frame: int,
        end_frame: int,
        frame_interval: int,
        max_frames: Optional[int] = None,
    ) -> int:
        """Return how many frames in ``[start_frame, end_frame)`` get sampled.

        Frames are sampled at offsets 0, interval, 2*interval, ... from
        ``start_frame``, so the count is the ceiling of span / interval,
        optionally capped by ``max_frames``.
        """
        span = max(end_frame - start_frame, 0)
        count = -(-span // frame_interval)  # ceiling division
        if max_frames:
            count = min(count, max_frames)
        return int(count)

    def process_video(
        self,
        video_path: Optional[str] = None,
        output_dir: Optional[str] = None,
    ) -> AnnotationResult:
        """
        Process entire video and generate annotations.

        Args:
            video_path: Path to video file (overrides config)
            output_dir: Output directory (overrides config)

        Returns:
            AnnotationResult with all detections

        Raises:
            ValueError: If no source is configured or the video cannot be opened.
        """
        if self.model is None:
            self.load_model()

        video_cfg = self.config.get('video', {})
        output_cfg = self.config.get('output', {})
        model_cfg = self.config.get('model', {})

        source = video_path or video_cfg.get('source')
        if source is None:
            raise ValueError("No video source specified")

        out_path = Path(output_dir or output_cfg.get('directory', 'output/annotations'))

        save_snapshots = output_cfg.get('save_snapshots', True)
        save_labels = output_cfg.get('save_labels', True)
        save_debug = output_cfg.get('save_debug', True)
        save_manifest = output_cfg.get('save_manifest', True)

        # Create output directories
        if save_manifest:
            # The manifest lives directly in out_path, which would not exist
            # if every other save flag were disabled.
            out_path.mkdir(parents=True, exist_ok=True)
        for enabled, subdir in (
            (save_snapshots, 'images'),
            (save_labels, 'labels'),
            (save_debug, 'debug'),
        ):
            if enabled:
                (out_path / subdir).mkdir(parents=True, exist_ok=True)

        conf_thresh = model_cfg.get('conf_threshold', 0.25)
        iou_thresh = model_cfg.get('iou_threshold', 0.45)
        resize = video_cfg.get('resize')
        img_format = output_cfg.get('image_format', 'jpg')
        img_quality = output_cfg.get('image_quality', 95)

        cap = cv2.VideoCapture(source)
        pbar = None
        processed = 0
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video: {source}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Sample every Nth frame when the requested rate is below the
            # native rate; otherwise process every frame.
            sample_fps = video_cfg.get('sample_fps')
            if sample_fps and sample_fps < fps:
                frame_interval = int(fps / sample_fps)
            else:
                frame_interval = 1

            # Frame range
            start_time = video_cfg.get('start_time', 0)
            end_time = video_cfg.get('end_time')
            max_frames = video_cfg.get('max_frames')

            start_frame = int(start_time * fps)
            end_frame = int(end_time * fps) if end_time else total_frames
            end_frame = min(end_frame, total_frames)

            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            result = AnnotationResult(
                video_path=source,
                total_frames=total_frames,
                processed_frames=0,
                total_detections=0,
                output_dir=str(out_path),
            )

            pbar = tqdm(
                total=self._count_sampled_frames(
                    start_frame, end_frame, frame_interval, max_frames
                ),
                desc="Annotating",
            )

            frame_idx = start_frame
            while frame_idx < end_frame:
                if max_frames and processed >= max_frames:
                    break

                ret, frame = cap.read()
                if not ret:
                    break

                if (frame_idx - start_frame) % frame_interval == 0:
                    if resize:
                        frame = cv2.resize(frame, tuple(resize))

                    # Dimensions after the optional resize; labels are
                    # normalized against these.
                    h, w = frame.shape[:2]

                    detections = self.process_frame(
                        frame,
                        frame_id=frame_idx,
                        conf_threshold=conf_thresh,
                        iou_threshold=iou_thresh,
                    )
                    detections = self.filter_detections(detections, w, h)

                    # Fill in the timestamp that process_frame cannot know.
                    timestamp = frame_idx / fps if fps > 0 else 0.0
                    for det in detections:
                        det.timestamp = timestamp

                    result.detections_per_frame[frame_idx] = detections
                    result.total_detections += len(detections)

                    # Outputs are only written for frames with detections.
                    if detections:
                        frame_name = f"frame_{frame_idx:06d}"

                        if save_snapshots:
                            img_path = out_path / 'images' / f"{frame_name}.{img_format}"
                            if img_format == 'jpg':
                                cv2.imwrite(
                                    str(img_path), frame,
                                    [cv2.IMWRITE_JPEG_QUALITY, img_quality],
                                )
                            else:
                                cv2.imwrite(str(img_path), frame)

                        if save_labels:
                            label_path = out_path / 'labels' / f"{frame_name}.txt"
                            self._write_yolo_label(label_path, detections, w, h)

                        if save_debug:
                            debug_frame = self._draw_detections(frame.copy(), detections)
                            debug_path = out_path / 'debug' / f"{frame_name}.{img_format}"
                            cv2.imwrite(str(debug_path), debug_frame)

                    processed += 1
                    pbar.update(1)

                frame_idx += 1
        finally:
            # Release the capture and progress bar even if detection or a
            # write raised part-way through the video.
            if pbar is not None:
                pbar.close()
            cap.release()

        result.processed_frames = processed

        if save_manifest:
            self._save_manifest(result, out_path / 'manifest.json')

        print("\nAnnotation complete!")
        print(f"  Processed frames: {processed}")
        print(f"  Total detections: {result.total_detections}")
        print(f"  Output: {out_path}")

        return result

    def process_frame(
        self,
        frame: np.ndarray,
        frame_id: int = 0,
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
    ) -> List[Detection]:
        """
        Process single frame and return detections.

        Args:
            frame: Input frame (BGR)
            frame_id: Frame index
            conf_threshold: Confidence threshold
            iou_threshold: NMS IoU threshold

        Returns:
            List of Detection objects
        """
        if self.model is None:
            self.load_model()

        results = self.model.predict(
            frame,
            conf=conf_threshold,
            iou=iou_threshold,
            verbose=False,
        )

        detections = []
        for result in results:
            if result.boxes is None:
                continue

            # Names carried by the model result, used only when the config
            # does not name a class. NOTE(review): assumes ``result.names``
            # is a dict of id -> name; anything else is ignored.
            model_names = getattr(result, 'names', None)
            if not isinstance(model_names, dict):
                model_names = {}

            for box in result.boxes:
                class_id = int(box.cls[0].item())
                confidence = float(box.conf[0].item())
                x1, y1, x2, y2 = box.xyxy[0].tolist()

                class_name = self.class_names.get(class_id)
                if class_name is None:
                    class_name = str(model_names.get(class_id, class_id))

                detections.append(
                    Detection(
                        class_id=class_id,
                        class_name=class_name,
                        confidence=confidence,
                        bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2),
                        frame_id=frame_id,
                    )
                )

        return detections

    def filter_detections(
        self,
        detections: List[Detection],
        img_width: int,
        img_height: int,
    ) -> List[Detection]:
        """
        Apply the ``detection`` config rules to a list of detections.

        A detection is kept only if its class is allowed, its confidence is at
        least ``min_confidence``, its pixel area lies within
        ``[min_area, max_area]`` and both normalized sides are at least
        ``min_size``.

        Args:
            detections: List of detections to filter
            img_width: Image width
            img_height: Image height

        Returns:
            Filtered list of detections
        """
        det_cfg = self.config.get('detection', {})

        allowed_classes = det_cfg.get('classes')  # None = all
        min_conf = det_cfg.get('min_confidence', 0.3)
        min_area = det_cfg.get('min_area', 100)
        max_area = det_cfg.get('max_area')  # None = no upper bound
        min_size = det_cfg.get('min_size', 0.01)

        filtered = []
        for det in detections:
            if allowed_classes is not None and det.class_id not in allowed_classes:
                continue
            if det.confidence < min_conf:
                continue

            area = det.bbox.area()
            if area < min_area:
                continue
            if max_area is not None and area > max_area:
                continue

            # Reject boxes that are too thin relative to the image.
            norm_w = det.bbox.width / img_width
            norm_h = det.bbox.height / img_height
            if norm_w < min_size or norm_h < min_size:
                continue

            filtered.append(det)

        return filtered

    def _write_yolo_label(
        self,
        path: Path,
        detections: List[Detection],
        img_width: int,
        img_height: int,
    ) -> None:
        """Write one ``class x_center y_center width height`` line per detection."""
        lines = []
        for det in detections:
            x_c, y_c, w, h = det.bbox.to_yolo(img_width, img_height)
            lines.append(f"{det.class_id} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")

        with open(path, 'w') as f:
            f.write('\n'.join(lines))

    def _draw_detections(
        self,
        frame: np.ndarray,
        detections: List[Detection],
    ) -> np.ndarray:
        """Draw boxes and ``name confidence`` labels onto ``frame`` in place."""
        color = (0, 255, 0)
        for det in detections:
            x1, y1 = int(det.bbox.x1), int(det.bbox.y1)
            x2, y2 = int(det.bbox.x2), int(det.bbox.y2)

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            label = f"{det.class_name} {det.confidence:.2f}"
            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

            # The label normally sits just above the box; when the box
            # touches the top of the frame, draw it just inside instead so it
            # is not cut off.
            label_top = y1 - label_h - 4
            if label_top < 0:
                label_top = y1
            label_bottom = label_top + label_h + 4

            cv2.rectangle(frame, (x1, label_top), (x1 + label_w, label_bottom), color, -1)
            cv2.putText(
                frame, label, (x1, label_bottom - 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1,
            )

        return frame

    def _save_manifest(self, result: AnnotationResult, path: Path) -> None:
        """Save the run summary and per-frame detections as JSON at ``path``."""
        import json

        manifest = {
            'created': result.created_at.isoformat(),
            'video_path': result.video_path,
            'total_frames': result.total_frames,
            'processed_frames': result.processed_frames,
            'total_detections': result.total_detections,
            'frames': {
                # JSON object keys must be strings.
                str(frame_id): [
                    {
                        'class_id': d.class_id,
                        'class_name': d.class_name,
                        'confidence': round(d.confidence, 4),
                        'bbox': [d.bbox.x1, d.bbox.y1, d.bbox.x2, d.bbox.y2],
                    }
                    for d in detections
                ]
                for frame_id, detections in result.detections_per_frame.items()
            },
        }

        with open(path, 'w') as f:
            json.dump(manifest, f, indent=2)
|
||||
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
Annotation export utilities for YOLO format output.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class AnnotationPair:
    """Record tying a saved snapshot image to its label file."""

    image_path: str  # path of the saved image
    label_path: str  # path of the matching label file
    frame_id: int  # frame index the pair was created from
    timestamp: float  # frame timestamp supplied by the caller
    detections: List[Any]  # detections written to the label file
    camera_name: str = "default"  # camera/source name
|
||||
|
||||
|
||||
class AnnotationExporter:
    """Export annotations in YOLO format with snapshot pairs."""

    def __init__(
        self,
        output_dir: str,
        class_names: Optional[Dict[int, str]] = None,
        image_format: str = "jpg",
        image_quality: int = 95,
    ):
        """
        Initialize exporter and create the ``images`` and ``labels`` directories.

        Args:
            output_dir: Root output directory
            class_names: Mapping of class_id to name
            image_format: Output image format (jpg, png)
            image_quality: JPEG quality (1-100)
        """
        self.output_dir = Path(output_dir)
        self.class_names = class_names or {}
        self.image_format = image_format
        self.image_quality = image_quality

        self.images_dir = self.output_dir / "images"
        self.labels_dir = self.output_dir / "labels"
        self.debug_dir = self.output_dir / "debug"  # created on first debug save

        self.images_dir.mkdir(parents=True, exist_ok=True)
        self.labels_dir.mkdir(parents=True, exist_ok=True)

        self._pairs: List[AnnotationPair] = []

    def save_pair(
        self,
        frame: np.ndarray,
        detections: List[Any],
        frame_id: int,
        timestamp: float = 0.0,
        camera_name: str = "default",
        save_debug: bool = False,
    ) -> Optional[AnnotationPair]:
        """
        Save snapshot and label pair.

        Args:
            frame: Image frame (BGR)
            detections: List of Detection objects (or dicts with ``bbox`` /
                ``class_id`` keys)
            frame_id: Frame index
            timestamp: Frame timestamp
            camera_name: Camera/source name
            save_debug: Save debug visualization

        Returns:
            AnnotationPair or None if no detections
        """
        if not detections:
            return None

        h, w = frame.shape[:2]

        # File name combines the source name, wall-clock time and frame index.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_name = f"{camera_name}_{ts}_{frame_id:06d}"

        image_path = self.images_dir / f"{base_name}.{self.image_format}"
        # Apply the quality setting to both common spellings of JPEG.
        if self.image_format.lower() in ("jpg", "jpeg"):
            cv2.imwrite(str(image_path), frame, [cv2.IMWRITE_JPEG_QUALITY, self.image_quality])
        else:
            cv2.imwrite(str(image_path), frame)

        label_path = self.labels_dir / f"{base_name}.txt"
        self._write_yolo_label(label_path, detections, w, h)

        if save_debug:
            self.debug_dir.mkdir(parents=True, exist_ok=True)
            debug_path = self.debug_dir / f"{base_name}_debug.{self.image_format}"
            debug_frame = self._draw_detections(frame.copy(), detections)
            cv2.imwrite(str(debug_path), debug_frame)

        pair = AnnotationPair(
            image_path=str(image_path),
            label_path=str(label_path),
            frame_id=frame_id,
            timestamp=timestamp,
            detections=detections,
            camera_name=camera_name,
        )
        self._pairs.append(pair)

        return pair

    def _write_yolo_label(
        self,
        path: Path,
        detections: List[Any],
        img_width: int,
        img_height: int,
    ) -> None:
        """Write one ``class x_center y_center width height`` line per detection.

        Detections may be objects exposing ``bbox.to_yolo`` and ``class_id``,
        or dicts with ``bbox`` (``[x1, y1, x2, y2]`` in pixels) and
        ``class_id`` keys.
        """
        lines = []
        for det in detections:
            if hasattr(det, 'bbox'):
                x_c, y_c, w, h = det.bbox.to_yolo(img_width, img_height)
                class_id = det.class_id
            else:
                # Dict format: normalize the pixel corners by hand.
                x1, y1, x2, y2 = det.get('bbox', [0, 0, 0, 0])
                class_id = det.get('class_id', 0)
                x_c = ((x1 + x2) / 2) / img_width
                y_c = ((y1 + y2) / 2) / img_height
                w = (x2 - x1) / img_width
                h = (y2 - y1) / img_height

            lines.append(f"{class_id} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")

        with open(path, 'w') as f:
            f.write('\n'.join(lines))

    def _draw_detections(
        self,
        frame: np.ndarray,
        detections: List[Any],
    ) -> np.ndarray:
        """Draw boxes and ``name confidence`` labels onto ``frame`` in place."""
        color = (0, 255, 0)
        for det in detections:
            if hasattr(det, 'bbox'):
                x1, y1 = int(det.bbox.x1), int(det.bbox.y1)
                x2, y2 = int(det.bbox.x2), int(det.bbox.y2)
                class_name = det.class_name
                conf = det.confidence
            else:
                x1, y1, x2, y2 = [int(v) for v in det.get('bbox', [0, 0, 0, 0])]
                class_name = det.get('class_name', str(det.get('class_id', 0)))
                conf = det.get('confidence', 0)

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            label = f"{class_name} {conf:.2f}"
            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(frame, (x1, y1 - label_h - 4), (x1 + label_w, y1), color, -1)
            cv2.putText(frame, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

        return frame

    def _relocate(self, images: List[Path], img_dir: Path, lbl_dir: Path) -> None:
        """Move images and their label files into a split, updating pair records.

        Args:
            images: Image paths currently located in ``self.images_dir``.
            img_dir: Destination directory for the images.
            lbl_dir: Destination directory for the matching label files.
        """
        moved: Dict[str, str] = {}
        for img in images:
            lbl = self.labels_dir / f"{img.stem}.txt"

            new_img = img_dir / img.name
            shutil.move(str(img), str(new_img))
            moved[str(img)] = str(new_img)

            if lbl.exists():
                new_lbl = lbl_dir / lbl.name
                shutil.move(str(lbl), str(new_lbl))
                moved[str(lbl)] = str(new_lbl)

        # Keep recorded pairs pointing at the files' new locations so a later
        # save_manifest() does not list paths that no longer exist.
        for pair in self._pairs:
            pair.image_path = moved.get(pair.image_path, pair.image_path)
            pair.label_path = moved.get(pair.label_path, pair.label_path)

    def create_dataset_yaml(
        self,
        train_ratio: float = 0.8,
        shuffle: bool = True,
        seed: Optional[int] = None,
    ) -> str:
        """
        Create YOLO data.yaml and split dataset.

        Images directly inside ``images/`` (and their label files) are moved
        into ``train``/``val`` subdirectories.

        Args:
            train_ratio: Ratio of training data
            shuffle: Shuffle data before splitting
            seed: Optional seed for a reproducible shuffle; when None the
                module-level ``random`` state is used

        Returns:
            Path to data.yaml
        """
        import random
        import yaml

        # Sort first so a seeded shuffle does not depend on directory order.
        images = sorted(self.images_dir.glob(f"*.{self.image_format}"))

        if shuffle:
            if seed is None:
                random.shuffle(images)
            else:
                random.Random(seed).shuffle(images)

        split_idx = int(len(images) * train_ratio)
        splits = {
            'train': images[:split_idx],
            'val': images[split_idx:],
        }

        for split_name, split_images in splits.items():
            img_dir = self.output_dir / "images" / split_name
            lbl_dir = self.output_dir / "labels" / split_name
            img_dir.mkdir(parents=True, exist_ok=True)
            lbl_dir.mkdir(parents=True, exist_ok=True)
            self._relocate(split_images, img_dir, lbl_dir)

        data_config = {
            'path': str(self.output_dir.absolute()),
            'train': 'images/train',
            'val': 'images/val',
            'names': self.class_names,
            'nc': len(self.class_names),
        }

        yaml_path = self.output_dir / "data.yaml"
        with open(yaml_path, 'w') as f:
            yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)

        print("Dataset created:")
        print(f"  Train: {len(splits['train'])} images")
        print(f"  Val: {len(splits['val'])} images")
        print(f"  Config: {yaml_path}")

        return str(yaml_path)

    def save_manifest(self) -> str:
        """
        Save manifest of all annotation pairs.

        Returns:
            Path to manifest file
        """
        manifest = {
            'created': datetime.now().isoformat(),
            'total_pairs': len(self._pairs),
            'output_dir': str(self.output_dir),
            'pairs': [
                {
                    'image': p.image_path,
                    'label': p.label_path,
                    'frame_id': p.frame_id,
                    'timestamp': p.timestamp,
                    'camera': p.camera_name,
                    'num_detections': len(p.detections),
                }
                for p in self._pairs
            ],
        }

        manifest_path = self.output_dir / "manifest.json"
        with open(manifest_path, 'w') as f:
            json.dump(manifest, f, indent=2)

        return str(manifest_path)
|
||||
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
Video source handler for MP4 and RTSP streams.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Generator, Tuple, Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread, Lock
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
|
||||
@dataclass
class VideoInfo:
    """Metadata describing an opened video."""

    path: str  # source path
    fps: float  # frames per second
    frame_count: int  # number of frames
    width: int  # frame width
    height: int  # frame height
    duration: float  # length of the video
    codec: str = ""  # empty unless the producer fills it in
|
||||
|
||||
|
||||
class VideoSource:
    """Video source handler supporting MP4 files."""

    def __init__(
        self,
        source: str,
        fps_limit: Optional[float] = None,
        loop: bool = False,
        resize: Optional[Tuple[int, int]] = None,
    ):
        """
        Initialize video source.

        Args:
            source: Path to MP4 file
            fps_limit: Maximum FPS to process
            loop: Loop video when finished
            resize: Resize frames to (width, height)
        """
        self.source = source
        self.fps_limit = fps_limit
        self.loop = loop
        self.resize = resize

        self.cap: Optional[cv2.VideoCapture] = None
        self.info: Optional[VideoInfo] = None
        self._frame_interval = 0  # set by open()
        self._is_running = False

    def open(self) -> VideoInfo:
        """Open video source and return info.

        Raises:
            FileNotFoundError: If the source path does not exist.
            ValueError: If the video cannot be opened.
        """
        path = Path(self.source)
        if not path.exists():
            raise FileNotFoundError(f"Video not found: {self.source}")

        # Release a capture left over from an earlier open() so it is not
        # leaked when the source is reopened.
        if self.cap is not None:
            self.cap.release()
            self.cap = None

        cap = cv2.VideoCapture(str(path))
        if not cap.isOpened():
            cap.release()
            raise ValueError(f"Cannot open video: {self.source}")
        self.cap = cap

        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Yield every Nth frame when the limit is below the native rate.
        if self.fps_limit and self.fps_limit < fps:
            self._frame_interval = int(fps / self.fps_limit)
        else:
            self._frame_interval = 1

        self.info = VideoInfo(
            path=str(path),
            fps=fps,
            frame_count=frame_count,
            width=width,
            height=height,
            duration=frame_count / fps if fps > 0 else 0,
        )

        self._is_running = True
        return self.info

    def close(self) -> None:
        """Close video source."""
        self._is_running = False
        if self.cap is not None:
            self.cap.release()
            self.cap = None

    def read(self) -> Tuple[bool, Optional[np.ndarray], int]:
        """
        Read next frame, opening the source first if necessary.

        Returns:
            Tuple of (success, frame, frame_index)
        """
        if self.cap is None:
            self.open()

        ret, frame = self.cap.read()
        frame_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) - 1

        if not ret:
            if not self.loop:
                return False, None, frame_idx
            # End of video while looping: rewind and retry once.
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ret, frame = self.cap.read()
            frame_idx = 0

        if ret and self.resize:
            frame = cv2.resize(frame, self.resize)

        return ret, frame, frame_idx

    def iterate(
        self,
        start_time: float = 0,
        end_time: Optional[float] = None,
        max_frames: Optional[int] = None,
    ) -> Generator[Tuple[int, np.ndarray, float], None, None]:
        """
        Iterate through video frames, honoring the FPS limit.

        Args:
            start_time: Start time in seconds
            end_time: End time in seconds
            max_frames: Maximum frames to yield

        Yields:
            Tuple of (frame_index, frame, timestamp)
        """
        if self.cap is None:
            self.open()

        fps = self.info.fps
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps) if end_time else self.info.frame_count

        self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frame_idx = start_frame
        yielded = 0
        read_since_rewind = 0

        while frame_idx < end_frame and self._is_running:
            if max_frames and yielded >= max_frames:
                break

            ret, frame = self.cap.read()
            if not ret:
                # Only rewind if the last pass produced frames; otherwise a
                # source that cannot be read would spin here forever.
                if self.loop and read_since_rewind > 0:
                    self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
                    frame_idx = start_frame
                    read_since_rewind = 0
                    continue
                break
            read_since_rewind += 1

            if (frame_idx - start_frame) % self._frame_interval == 0:
                if self.resize:
                    frame = cv2.resize(frame, self.resize)

                # Same guard as open() uses for the duration: fps may be 0.
                timestamp = frame_idx / fps if fps > 0 else 0.0
                yield frame_idx, frame, timestamp
                yielded += 1

            frame_idx += 1

    def __enter__(self):
        """Open the source and return this instance."""
        self.open()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the source; exceptions are not suppressed."""
        self.close()
|
||||
|
||||
|
||||
class BufferedVideoSource(VideoSource):
    """Video source with frame buffering for smoother processing.

    A background thread reads frames into a bounded queue and ``read()``
    takes frames from that queue.

    NOTE(review): the inherited ``iterate()`` reads from the capture directly
    and would compete with the reader thread; confirm whether it is meant to
    be used on this subclass.
    """

    def __init__(
        self,
        source: str,
        buffer_size: int = 10,
        **kwargs,
    ):
        """
        Initialize buffered video source.

        Args:
            source: Path to video
            buffer_size: Number of frames to buffer
            **kwargs: Additional VideoSource arguments
        """
        super().__init__(source, **kwargs)
        self.buffer_size = buffer_size
        self._buffer: Queue = Queue(maxsize=buffer_size)
        self._thread: Optional[Thread] = None
        self._lock = Lock()  # guards every use of self.cap across threads

    def _reader_thread(self) -> None:
        """Background thread to read frames into buffer."""
        while self._is_running:
            if self._buffer.full():
                time.sleep(0.001)
                continue

            with self._lock:
                if self.cap is None:
                    break
                ret, frame = self.cap.read()
                frame_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) - 1
                # Rewind under the same lock acquisition so close() cannot
                # release the capture between the failed read and the seek.
                if not ret and self.loop:
                    self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

            if not ret:
                if self.loop:
                    continue
                # Sentinel telling the consumer the video has ended.
                self._buffer.put((False, None, -1))
                break

            if self.resize:
                frame = cv2.resize(frame, self.resize)

            self._buffer.put((True, frame, frame_idx))

    def open(self) -> VideoInfo:
        """Open video and start buffer thread."""
        # Stop a reader left over from an earlier open() before the capture
        # underneath it is replaced.
        if self._thread is not None and self._thread.is_alive():
            self.close()

        info = super().open()

        # Start from an empty buffer so frames or an end-of-video sentinel
        # from a previous session are not returned.
        self._buffer = Queue(maxsize=self.buffer_size)

        self._thread = Thread(target=self._reader_thread, daemon=True)
        self._thread.start()

        return info

    def close(self) -> None:
        """Stop buffer thread and close video."""
        self._is_running = False
        if self._thread is not None:
            self._thread.join(timeout=1.0)
            self._thread = None
        # Take the lock so the capture is not released while the reader
        # thread is in the middle of a read.
        with self._lock:
            super().close()

    def read(self) -> Tuple[bool, Optional[np.ndarray], int]:
        """Read frame from buffer.

        Returns:
            Tuple of (success, frame, frame_index); ``(False, None, -1)`` when
            the source is not running, the video has ended, or no frame
            arrived within one second.
        """
        from queue import Empty

        if not self._is_running:
            return False, None, -1

        try:
            return self._buffer.get(timeout=1.0)
        except Empty:
            # Only a timeout is expected here; other errors should surface.
            return False, None, -1
|
||||
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Debug visualization utilities.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class DebugVisualizer:
|
||||
"""Visualization tools for debugging detections."""
|
||||
|
||||
def __init__(
    self,
    box_color: Tuple[int, int, int] = (0, 255, 0),
    box_thickness: int = 2,
    font_scale: float = 0.5,
    show_confidence: bool = True,
    show_class: bool = True,
    show_fps: bool = True,
):
    """
    Store drawing options and reset the FPS bookkeeping.

    Args:
        box_color: BGR color for bounding boxes
        box_thickness: Line thickness for boxes
        font_scale: Font scale for labels
        show_confidence: Show confidence scores
        show_class: Show class names
        show_fps: Show FPS counter
    """
    # Box and text appearance.
    self.box_color, self.box_thickness = box_color, box_thickness
    self.font_scale = font_scale

    # Which label parts and overlays to render.
    self.show_confidence, self.show_class = show_confidence, show_class
    self.show_fps = show_fps

    # FPS bookkeeping: starts with no samples and the current time.
    self._fps_history: List[float] = []
    self._last_time = datetime.now()
|
||||
|
||||
def draw_detections(
|
||||
self,
|
||||
frame: np.ndarray,
|
||||
detections: List[Any],
|
||||
colors: Optional[Dict[int, Tuple[int, int, int]]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Draw detection boxes on frame.
|
||||
|
||||
Args:
|
||||
frame: Input frame (BGR)
|
||||
detections: List of Detection objects
|
||||
colors: Optional class_id to color mapping
|
||||
|
||||
Returns:
|
||||
Frame with annotations drawn
|
||||
"""
|
||||
output = frame.copy()
|
||||
|
||||
for det in detections:
|
||||
# Get detection info
|
||||
if hasattr(det, 'bbox'):
|
||||
x1, y1 = int(det.bbox.x1), int(det.bbox.y1)
|
||||
x2, y2 = int(det.bbox.x2), int(det.bbox.y2)
|
||||
class_id = det.class_id
|
||||
class_name = det.class_name
|
||||
conf = det.confidence
|
||||
else:
|
||||
bbox = det.get('bbox', [0, 0, 0, 0])
|
||||
x1, y1, x2, y2 = [int(x) for x in bbox]
|
||||
class_id = det.get('class_id', 0)
|
||||
class_name = det.get('class_name', str(class_id))
|
||||
conf = det.get('confidence', 0)
|
||||
|
||||
# Get color
|
||||
if colors and class_id in colors:
|
||||
color = colors[class_id]
|
||||
else:
|
||||
color = self.box_color
|
||||
|
||||
# Draw box
|
||||
cv2.rectangle(output, (x1, y1), (x2, y2), color, self.box_thickness)
|
||||
|
||||
# Build label
|
||||
label_parts = []
|
||||
if self.show_class:
|
||||
label_parts.append(class_name)
|
||||
if self.show_confidence:
|
||||
label_parts.append(f"{conf:.2f}")
|
||||
|
||||
if label_parts:
|
||||
label = " ".join(label_parts)
|
||||
(label_w, label_h), baseline = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_SIMPLEX, self.font_scale, 1
|
||||
)
|
||||
|
||||
# Draw label background
|
||||
cv2.rectangle(
|
||||
output,
|
||||
(x1, y1 - label_h - 4),
|
||||
(x1 + label_w + 4, y1),
|
||||
color,
|
||||
-1
|
||||
)
|
||||
|
||||
# Draw label text
|
||||
cv2.putText(
|
||||
output,
|
||||
label,
|
||||
(x1 + 2, y1 - 2),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
self.font_scale,
|
||||
(0, 0, 0),
|
||||
1
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def draw_fps(self, frame: np.ndarray, fps: Optional[float] = None) -> np.ndarray:
|
||||
"""
|
||||
Draw FPS counter on frame.
|
||||
|
||||
Args:
|
||||
frame: Input frame
|
||||
fps: FPS value (auto-calculated if None)
|
||||
|
||||
Returns:
|
||||
Frame with FPS drawn
|
||||
"""
|
||||
if fps is None:
|
||||
# Calculate FPS
|
||||
now = datetime.now()
|
||||
dt = (now - self._last_time).total_seconds()
|
||||
self._last_time = now
|
||||
|
||||
if dt > 0:
|
||||
current_fps = 1.0 / dt
|
||||
self._fps_history.append(current_fps)
|
||||
if len(self._fps_history) > 30:
|
||||
self._fps_history.pop(0)
|
||||
fps = sum(self._fps_history) / len(self._fps_history)
|
||||
else:
|
||||
fps = 0
|
||||
|
||||
output = frame.copy()
|
||||
fps_text = f"FPS: {fps:.1f}"
|
||||
|
||||
cv2.putText(
|
||||
output,
|
||||
fps_text,
|
||||
(10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1.0,
|
||||
(0, 255, 0),
|
||||
2
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def draw_stats(
|
||||
self,
|
||||
frame: np.ndarray,
|
||||
stats: Dict[str, Any],
|
||||
position: Tuple[int, int] = (10, 60),
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Draw statistics on frame.
|
||||
|
||||
Args:
|
||||
frame: Input frame
|
||||
stats: Dictionary of stats to display
|
||||
position: Starting position (x, y)
|
||||
|
||||
Returns:
|
||||
Frame with stats drawn
|
||||
"""
|
||||
output = frame.copy()
|
||||
x, y = position
|
||||
line_height = 25
|
||||
|
||||
for key, value in stats.items():
|
||||
if isinstance(value, float):
|
||||
text = f"{key}: {value:.2f}"
|
||||
else:
|
||||
text = f"{key}: {value}"
|
||||
|
||||
cv2.putText(
|
||||
output,
|
||||
text,
|
||||
(x, y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.6,
|
||||
(255, 255, 255),
|
||||
2
|
||||
)
|
||||
cv2.putText(
|
||||
output,
|
||||
text,
|
||||
(x, y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.6,
|
||||
(0, 0, 0),
|
||||
1
|
||||
)
|
||||
|
||||
y += line_height
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ObjectListDisplay:
    """Render detections as plain text for console or log output."""

    def __init__(
        self,
        show_confidence: bool = True,
        show_class: bool = True,
        show_bbox: bool = True,
        show_track_id: bool = False,
    ):
        """
        Configure which fields appear in each formatted detection.

        Args:
            show_confidence: Include the confidence score
            show_class: Include the class name
            show_bbox: Include the bounding box coordinates
            show_track_id: Include the tracking ID when one is present
        """
        self.show_confidence = show_confidence
        self.show_class = show_class
        self.show_bbox = show_bbox
        self.show_track_id = show_track_id

    def format_detection(self, detection: Any) -> str:
        """
        Build a one-line summary of a single detection.

        Args:
            detection: Either an object with ``bbox`` (exposing ``x1``,
                ``y1``, ``x2``, ``y2``), ``class_name`` and ``confidence``
                attributes (``track_id`` optional), or a dict using the
                same names as keys with ``'bbox'`` holding four values.

        Returns:
            The enabled fields joined with " | "
        """
        if hasattr(detection, 'bbox'):
            # Attribute-style detection object.
            label = detection.class_name
            score = detection.confidence
            box = detection.bbox
            tid = getattr(detection, 'track_id', None)
            left, top, right, bottom = box.x1, box.y1, box.x2, box.y2
        else:
            # Dict-style detection; missing keys fall back to defaults.
            label = detection.get('class_name', str(detection.get('class_id', 0)))
            score = detection.get('confidence', 0)
            box = detection.get('bbox', [0, 0, 0, 0])
            tid = detection.get('track_id')
            left, top, right, bottom = box

        fields = []
        if self.show_class:
            fields.append(f"{label}")
        if self.show_confidence:
            fields.append(f"conf={score:.2f}")
        if self.show_bbox:
            fields.append(f"bbox=[{left:.0f},{top:.0f},{right:.0f},{bottom:.0f}]")
        if tid is not None and self.show_track_id:
            fields.append(f"id={tid}")

        return " | ".join(fields)

    def format_list(self, detections: List[Any], frame_id: int = 0) -> str:
        """
        Build a multi-line summary of all detections in a frame.

        Args:
            detections: Detections accepted by format_detection()
            frame_id: Frame index shown in the header line

        Returns:
            A header line followed by one indexed line per detection
        """
        header = f"Frame {frame_id}: {len(detections)} objects"
        rows = [
            f" [{index}] {self.format_detection(item)}"
            for index, item in enumerate(detections)
        ]
        return "\n".join([header, *rows])

    def print_list(self, detections: List[Any], frame_id: int = 0) -> None:
        """Write the output of format_list() to stdout."""
        summary = self.format_list(detections, frame_id)
        print(summary)
|
||||
Reference in New Issue
Block a user