add sam2 yolo auto annotation

This commit is contained in:
2026-02-04 15:29:36 +07:00
parent 7e56948ece
commit 5a951d8812
2061 changed files with 316473 additions and 0 deletions
+19
View File
@@ -0,0 +1,19 @@
"""
YOLO-Assisted Video Annotator
Auto-annotate videos using pretrained YOLOv9t model.
Outputs clean snapshots paired with YOLO format labels.
"""
from .annotator import YOLOAnnotator
from .video_source import VideoSource
from .export import AnnotationExporter
from .visualizer import DebugVisualizer
__version__ = "0.1.0"
__all__ = [
"YOLOAnnotator",
"VideoSource",
"AnnotationExporter",
"DebugVisualizer",
]
+487
View File
@@ -0,0 +1,487 @@
"""
Core YOLO-based video annotator.
Uses pretrained YOLOv9t to automatically detect and annotate objects in video frames.
"""
import cv2
import yaml
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from tqdm import tqdm
try:
from ultralytics import YOLO
HAS_ULTRALYTICS = True
except ImportError:
HAS_ULTRALYTICS = False
@dataclass
class BBox:
    """Axis-aligned rectangle described by two corner points, in pixels."""

    x1: float  # left edge (top-left corner x)
    y1: float  # top edge (top-left corner y)
    x2: float  # right edge (bottom-right corner x)
    y2: float  # bottom edge (bottom-right corner y)

    @property
    def width(self) -> float:
        """Horizontal extent of the box: ``x2 - x1``."""
        return self.x2 - self.x1

    @property
    def height(self) -> float:
        """Vertical extent of the box: ``y2 - y1``."""
        return self.y2 - self.y1

    def area(self) -> float:
        """Return the box area in square pixels (width times height)."""
        return self.width * self.height

    def to_yolo(self, img_width: int, img_height: int) -> Tuple[float, float, float, float]:
        """Return ``(x_center, y_center, width, height)`` divided by the image size.

        Args:
            img_width: Width of the image the box refers to, in pixels.
            img_height: Height of the image the box refers to, in pixels.

        Returns:
            The box centre and size, each divided by the matching image
            dimension.
        """
        center_x = ((self.x1 + self.x2) / 2) / img_width
        center_y = ((self.y1 + self.y2) / 2) / img_height
        return (
            center_x,
            center_y,
            self.width / img_width,
            self.height / img_height,
        )
@dataclass
class Detection:
    """One detected object inside one video frame.

    Attributes:
        class_id: Integer class index reported by the detector.
        class_name: Display name for ``class_id``.
        confidence: Detector confidence score.
        bbox: Location of the object as a pixel-space ``BBox``.
        frame_id: Index of the frame the object was found in.
        timestamp: Time associated with the frame; defaults to ``0.0``.
        track_id: Optional tracking identifier; ``None`` when not tracked.
    """

    class_id: int
    class_name: str
    confidence: float
    bbox: BBox
    frame_id: int
    timestamp: float = 0.0
    track_id: Optional[int] = None
@dataclass
class AnnotationResult:
    """Summary of one annotation run over a video.

    Attributes:
        video_path: Source that was annotated.
        total_frames: Frame count reported for the video.
        processed_frames: Number of frames that were actually analysed.
        total_detections: Number of detections kept across all frames.
        detections_per_frame: Kept detections, keyed by frame index.
        output_dir: Directory the run wrote into, if any.
        created_at: Creation time of this record (``datetime.now`` at
            construction).
    """

    video_path: str
    total_frames: int
    processed_frames: int
    total_detections: int
    detections_per_frame: Dict[int, List[Detection]] = field(default_factory=dict)
    output_dir: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
class YOLOAnnotator:
    """Automatic video annotator driven by an Ultralytics YOLO model.

    Frames are sampled from a video, run through the detector, filtered by
    the rules in the ``detection`` config section and written out as
    snapshot images, YOLO-format label files, debug overlays and a JSON
    manifest, as selected by the ``output`` config section.
    """

    def __init__(self, config_path: Optional[str] = None, config: Optional[Dict] = None):
        """
        Initialize annotator.

        Args:
            config_path: Path to YAML config file
            config: Config dictionary (overrides config_path)

        Raises:
            ImportError: If the ``ultralytics`` package is not installed.
            ValueError: If the YAML file does not contain a mapping.
        """
        if not HAS_ULTRALYTICS:
            raise ImportError("ultralytics package required. Install with: pip install ultralytics")
        self.config = self._load_config(config_path, config)
        self.model = None  # loaded lazily by load_model()
        # ``or {}`` also covers a config that sets ``class_names:`` to null.
        self.class_names = self.config.get('class_names') or {}

    @staticmethod
    def _default_config() -> Dict:
        """Return the built-in configuration used when none is supplied."""
        return {
            'model': {
                'path': 'yolov9t.pt',
                'device': 'cuda',
                'conf_threshold': 0.25,
                'iou_threshold': 0.45,
            },
            'video': {
                'sample_fps': 2,
                'max_frames': None,
                'start_time': 0,
                'end_time': None,
                'resize': None,
            },
            'detection': {
                'classes': None,
                'min_confidence': 0.3,
                'min_area': 100,
                'max_area': None,
                'min_size': 0.01,
            },
            'output': {
                'directory': 'output/annotations',
                'save_snapshots': True,
                'save_labels': True,
                'save_debug': True,
                'save_manifest': True,
                'image_format': 'jpg',
                'image_quality': 95,
            },
        }

    def _load_config(self, config_path: Optional[str], config: Optional[Dict]) -> Dict:
        """Load configuration from a dict, a YAML file, or the defaults.

        Precedence: ``config`` dict, then ``config_path``, then the built-in
        defaults. An empty YAML file is treated like "no config" instead of
        handing ``None`` back to the caller.

        Raises:
            ValueError: If the YAML document is not a mapping.
        """
        if config is not None:
            return config
        if config_path is not None:
            with open(config_path, 'r') as f:
                loaded = yaml.safe_load(f)
            if loaded is not None:
                if not isinstance(loaded, dict):
                    raise ValueError(f"Config file must contain a mapping: {config_path}")
                return loaded
            # safe_load() yields None for an empty document; use defaults.
        return self._default_config()

    def load_model(self, model_path: Optional[str] = None, device: Optional[str] = None) -> None:
        """
        Load the YOLO model and move it to the target device.

        Args:
            model_path: Path to model weights (overrides config)
            device: Device to use (overrides config)
        """
        model_cfg = self.config.get('model', {})
        path = model_path or model_cfg.get('path', 'yolov9t.pt')
        dev = device or model_cfg.get('device', 'cuda')
        print(f"Loading model: {path}")
        self.model = YOLO(path)
        self.model.to(dev)
        print(f"Model loaded on {dev}")

    def process_video(
        self,
        video_path: Optional[str] = None,
        output_dir: Optional[str] = None
    ) -> AnnotationResult:
        """
        Process entire video and generate annotations.

        Every ``frame_interval``-th frame between the configured start and
        end times is analysed; frames with at least one surviving detection
        are written to disk according to the ``output`` config section.

        Args:
            video_path: Path to video file (overrides config)
            output_dir: Output directory (overrides config)

        Returns:
            AnnotationResult with all detections

        Raises:
            ValueError: If no source is configured or it cannot be opened.
        """
        if self.model is None:
            self.load_model()

        video_cfg = self.config.get('video', {})
        output_cfg = self.config.get('output', {})
        model_cfg = self.config.get('model', {})

        source = video_path or video_cfg.get('source')
        if source is None:
            raise ValueError("No video source specified")

        out_dir = output_dir or output_cfg.get('directory', 'output/annotations')
        out_path = Path(out_dir)
        # Always create the root: the manifest is written there even when
        # every save_* sub-directory is disabled.
        out_path.mkdir(parents=True, exist_ok=True)
        for flag, subdir in (
            ('save_snapshots', 'images'),
            ('save_labels', 'labels'),
            ('save_debug', 'debug'),
        ):
            if output_cfg.get(flag, True):
                (out_path / subdir).mkdir(parents=True, exist_ok=True)

        conf_thresh = model_cfg.get('conf_threshold', 0.25)
        iou_thresh = model_cfg.get('iou_threshold', 0.45)
        resize = video_cfg.get('resize')
        max_frames = video_cfg.get('max_frames')

        cap = cv2.VideoCapture(source)
        pbar = None
        processed = 0
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video: {source}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Sample every N-th frame so roughly sample_fps frames per
            # second of video are analysed.
            sample_fps = video_cfg.get('sample_fps')
            if sample_fps and sample_fps < fps:
                frame_interval = int(fps / sample_fps)
            else:
                frame_interval = 1

            start_time = video_cfg.get('start_time') or 0
            end_time = video_cfg.get('end_time')
            start_frame = int(start_time * fps)
            end_frame = int(end_time * fps) if end_time else total_frames
            end_frame = min(end_frame, total_frames)
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            result = AnnotationResult(
                video_path=source,
                total_frames=total_frames,
                processed_frames=0,
                total_detections=0,
                output_dir=str(out_path),
            )

            # Ceiling division: frames start, start+N, start+2N, ... are
            # sampled, so a partial last interval still contributes one.
            span = max(end_frame - start_frame, 0)
            expected = -(-span // frame_interval)
            if max_frames:
                expected = min(expected, max_frames)
            pbar = tqdm(total=int(expected), desc="Annotating")

            frame_idx = start_frame
            while frame_idx < end_frame:
                if max_frames and processed >= max_frames:
                    break
                # grab() advances the stream without decoding; only sampled
                # frames pay for retrieve().
                if not cap.grab():
                    break
                if (frame_idx - start_frame) % frame_interval == 0:
                    ret, frame = cap.retrieve()
                    if not ret:
                        break
                    if resize:
                        frame = cv2.resize(frame, tuple(resize))
                    h, w = frame.shape[:2]

                    detections = self.process_frame(
                        frame,
                        frame_id=frame_idx,
                        conf_threshold=conf_thresh,
                        iou_threshold=iou_thresh,
                    )
                    detections = self.filter_detections(detections, w, h)

                    timestamp = frame_idx / fps if fps > 0 else 0.0
                    for det in detections:
                        det.timestamp = timestamp

                    result.detections_per_frame[frame_idx] = detections
                    result.total_detections += len(detections)

                    if detections:
                        self._save_frame_outputs(
                            frame,
                            detections,
                            f"frame_{frame_idx:06d}",
                            out_path,
                            output_cfg,
                        )
                    processed += 1
                    pbar.update(1)
                frame_idx += 1
        finally:
            # Release the progress bar and capture even if detection or
            # file output raised part-way through.
            if pbar is not None:
                pbar.close()
            cap.release()

        result.processed_frames = processed

        if output_cfg.get('save_manifest', True):
            self._save_manifest(result, out_path / 'manifest.json')

        print(f"\nAnnotation complete!")
        print(f" Processed frames: {processed}")
        print(f" Total detections: {result.total_detections}")
        print(f" Output: {out_path}")
        return result

    def _save_frame_outputs(
        self,
        frame: np.ndarray,
        detections: List[Detection],
        frame_name: str,
        out_path: Path,
        output_cfg: Dict,
    ) -> None:
        """Write the snapshot, label file and debug overlay for one frame."""
        h, w = frame.shape[:2]
        img_format = output_cfg.get('image_format', 'jpg')
        img_quality = output_cfg.get('image_quality', 95)

        if output_cfg.get('save_snapshots', True):
            img_path = out_path / 'images' / f"{frame_name}.{img_format}"
            if img_format == 'jpg':
                cv2.imwrite(str(img_path), frame, [cv2.IMWRITE_JPEG_QUALITY, img_quality])
            else:
                cv2.imwrite(str(img_path), frame)

        if output_cfg.get('save_labels', True):
            label_path = out_path / 'labels' / f"{frame_name}.txt"
            self._write_yolo_label(label_path, detections, w, h)

        if output_cfg.get('save_debug', True):
            # Draw on a copy so the clean snapshot frame stays untouched.
            debug_frame = self._draw_detections(frame.copy(), detections)
            debug_path = out_path / 'debug' / f"{frame_name}.{img_format}"
            cv2.imwrite(str(debug_path), debug_frame)

    def process_frame(
        self,
        frame: np.ndarray,
        frame_id: int = 0,
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
    ) -> List[Detection]:
        """
        Process single frame and return detections.

        Args:
            frame: Input frame (BGR)
            frame_id: Frame index
            conf_threshold: Confidence threshold
            iou_threshold: NMS IoU threshold

        Returns:
            List of Detection objects
        """
        if self.model is None:
            self.load_model()

        results = self.model.predict(
            frame,
            conf=conf_threshold,
            iou=iou_threshold,
            verbose=False,
        )

        detections = []
        for result in results:
            if result.boxes is None:
                continue
            for box in result.boxes:
                class_id = int(box.cls[0].item())
                confidence = float(box.conf[0].item())
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                # Fall back to the numeric id when no name is configured.
                class_name = self.class_names.get(class_id, str(class_id))
                detections.append(
                    Detection(
                        class_id=class_id,
                        class_name=class_name,
                        confidence=confidence,
                        bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2),
                        frame_id=frame_id,
                    )
                )
        return detections

    def filter_detections(
        self,
        detections: List[Detection],
        img_width: int,
        img_height: int,
    ) -> List[Detection]:
        """
        Apply filtering rules to detections.

        A detection is kept only if it passes the class whitelist, the
        minimum confidence, the pixel-area bounds and the minimum
        normalized width/height from the ``detection`` config section.

        Args:
            detections: List of detections to filter
            img_width: Image width
            img_height: Image height

        Returns:
            Filtered list of detections
        """
        det_cfg = self.config.get('detection', {})
        allowed_classes = det_cfg.get('classes')  # None = all
        min_conf = det_cfg.get('min_confidence', 0.3)
        min_area = det_cfg.get('min_area', 100)
        max_area = det_cfg.get('max_area')
        min_size = det_cfg.get('min_size', 0.01)

        filtered = []
        for det in detections:
            if allowed_classes is not None and det.class_id not in allowed_classes:
                continue
            if det.confidence < min_conf:
                continue
            area = det.bbox.area()
            if area < min_area:
                continue
            if max_area is not None and area > max_area:
                continue
            norm_w = det.bbox.width / img_width
            norm_h = det.bbox.height / img_height
            if norm_w < min_size or norm_h < min_size:
                continue
            filtered.append(det)
        return filtered

    def _write_yolo_label(
        self,
        path: Path,
        detections: List[Detection],
        img_width: int,
        img_height: int,
    ) -> None:
        """Write one ``class x_center y_center width height`` line per detection."""
        lines = []
        for det in detections:
            x_c, y_c, w, h = det.bbox.to_yolo(img_width, img_height)
            lines.append(f"{det.class_id} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")
        with open(path, 'w') as f:
            f.write('\n'.join(lines))

    def _draw_detections(
        self,
        frame: np.ndarray,
        detections: List[Detection],
    ) -> np.ndarray:
        """Draw boxes and ``name confidence`` labels onto ``frame`` in place.

        Returns the same array that was passed in.
        """
        color = (0, 255, 0)
        for det in detections:
            x1, y1 = int(det.bbox.x1), int(det.bbox.y1)
            x2, y2 = int(det.bbox.x2), int(det.bbox.y2)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            # Filled strip above the box with the label text on top of it.
            label = f"{det.class_name} {det.confidence:.2f}"
            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(frame, (x1, y1 - label_h - 4), (x1 + label_w, y1), color, -1)
            cv2.putText(frame, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
        return frame

    def _save_manifest(self, result: AnnotationResult, path: Path) -> None:
        """Save annotation manifest as JSON (run totals plus per-frame boxes)."""
        import json

        manifest = {
            'created': result.created_at.isoformat(),
            'video_path': result.video_path,
            'total_frames': result.total_frames,
            'processed_frames': result.processed_frames,
            'total_detections': result.total_detections,
            'frames': {},
        }
        for frame_id, detections in result.detections_per_frame.items():
            manifest['frames'][str(frame_id)] = [
                {
                    'class_id': d.class_id,
                    'class_name': d.class_name,
                    'confidence': round(d.confidence, 4),
                    'bbox': [d.bbox.x1, d.bbox.y1, d.bbox.x2, d.bbox.y2],
                }
                for d in detections
            ]
        with open(path, 'w') as f:
            json.dump(manifest, f, indent=2)
+280
View File
@@ -0,0 +1,280 @@
"""
Annotation export utilities for YOLO format output.
"""
import cv2
import json
import shutil
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
import numpy as np
@dataclass
class AnnotationPair:
    """Record linking a saved snapshot image to its label file.

    Attributes:
        image_path: Location of the saved image.
        label_path: Location of the matching label file.
        frame_id: Index of the frame the pair was produced from.
        timestamp: Timestamp supplied for the frame.
        detections: Detections written into the label file.
        camera_name: Name of the camera/source; defaults to ``"default"``.
    """

    image_path: str
    label_path: str
    frame_id: int
    timestamp: float
    detections: List[Any]
    camera_name: str = "default"
class AnnotationExporter:
    """Export annotations in YOLO format with snapshot pairs.

    Images go to ``<output_dir>/images`` and labels to
    ``<output_dir>/labels``; every saved pair is remembered so that
    ``save_manifest`` can list it later.
    """

    def __init__(
        self,
        output_dir: str,
        class_names: Optional[Dict[int, str]] = None,
        image_format: str = "jpg",
        image_quality: int = 95,
    ):
        """
        Initialize exporter and create the image/label directories.

        Args:
            output_dir: Root output directory
            class_names: Mapping of class_id to name
            image_format: Output image format (jpg, png)
            image_quality: JPEG quality (1-100)
        """
        self.output_dir = Path(output_dir)
        self.class_names = class_names or {}
        self.image_format = image_format
        self.image_quality = image_quality

        self.images_dir = self.output_dir / "images"
        self.labels_dir = self.output_dir / "labels"
        self.debug_dir = self.output_dir / "debug"  # created on first debug save
        self.images_dir.mkdir(parents=True, exist_ok=True)
        self.labels_dir.mkdir(parents=True, exist_ok=True)

        self._pairs: List[AnnotationPair] = []

    def save_pair(
        self,
        frame: np.ndarray,
        detections: List[Any],
        frame_id: int,
        timestamp: float = 0.0,
        camera_name: str = "default",
        save_debug: bool = False,
    ) -> Optional[AnnotationPair]:
        """
        Save snapshot and label pair.

        The file stem is ``<camera>_<wall-clock time>_<frame id>``; the
        wall-clock part comes from ``datetime.now()``, not from
        ``timestamp``.

        Args:
            frame: Image frame (BGR)
            detections: List of Detection objects (or equivalent dicts)
            frame_id: Frame index
            timestamp: Frame timestamp
            camera_name: Camera/source name
            save_debug: Save debug visualization

        Returns:
            AnnotationPair or None if no detections
        """
        if not detections:
            return None

        h, w = frame.shape[:2]
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_name = f"{camera_name}_{ts}_{frame_id:06d}"

        image_path = self.images_dir / f"{base_name}.{self.image_format}"
        # Honour image_quality for both common JPEG spellings.
        if self.image_format.lower() in ("jpg", "jpeg"):
            cv2.imwrite(str(image_path), frame, [cv2.IMWRITE_JPEG_QUALITY, self.image_quality])
        else:
            cv2.imwrite(str(image_path), frame)

        label_path = self.labels_dir / f"{base_name}.txt"
        self._write_yolo_label(label_path, detections, w, h)

        if save_debug:
            self.debug_dir.mkdir(parents=True, exist_ok=True)
            debug_path = self.debug_dir / f"{base_name}_debug.{self.image_format}"
            # Draw on a copy so the caller's frame is left clean.
            debug_frame = self._draw_detections(frame.copy(), detections)
            cv2.imwrite(str(debug_path), debug_frame)

        pair = AnnotationPair(
            image_path=str(image_path),
            label_path=str(label_path),
            frame_id=frame_id,
            timestamp=timestamp,
            detections=detections,
            camera_name=camera_name,
        )
        self._pairs.append(pair)
        return pair

    def _write_yolo_label(
        self,
        path: Path,
        detections: List[Any],
        img_width: int,
        img_height: int,
    ) -> None:
        """Write YOLO format label file.

        Accepts objects exposing ``bbox``/``class_id`` attributes as well as
        dicts with ``'bbox'`` (``[x1, y1, x2, y2]``) and ``'class_id'`` keys.
        """
        lines = []
        for det in detections:
            if hasattr(det, 'bbox'):
                x_c, y_c, w, h = det.bbox.to_yolo(img_width, img_height)
                class_id = det.class_id
            else:
                # Dict format: convert corner coordinates by hand.
                x1, y1, x2, y2 = det.get('bbox', [0, 0, 0, 0])
                class_id = det.get('class_id', 0)
                x_c = ((x1 + x2) / 2) / img_width
                y_c = ((y1 + y2) / 2) / img_height
                w = (x2 - x1) / img_width
                h = (y2 - y1) / img_height
            lines.append(f"{class_id} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")
        with open(path, 'w') as f:
            f.write('\n'.join(lines))

    def _draw_detections(
        self,
        frame: np.ndarray,
        detections: List[Any],
    ) -> np.ndarray:
        """Draw detection boxes and labels onto ``frame`` in place and return it."""
        color = (0, 255, 0)
        for det in detections:
            if hasattr(det, 'bbox'):
                x1, y1 = int(det.bbox.x1), int(det.bbox.y1)
                x2, y2 = int(det.bbox.x2), int(det.bbox.y2)
                class_name = det.class_name
                conf = det.confidence
            else:
                bbox = det.get('bbox', [0, 0, 0, 0])
                x1, y1, x2, y2 = [int(x) for x in bbox]
                class_name = det.get('class_name', str(det.get('class_id', 0)))
                conf = det.get('confidence', 0)

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            label = f"{class_name} {conf:.2f}"
            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(frame, (x1, y1 - label_h - 4), (x1 + label_w, y1), color, -1)
            cv2.putText(frame, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
        return frame

    def create_dataset_yaml(
        self,
        train_ratio: float = 0.8,
        shuffle: bool = True,
    ) -> str:
        """
        Create YOLO data.yaml and split dataset.

        Images directly inside ``images/`` (and their labels, when present)
        are moved into ``train``/``val`` sub-directories. Recorded pairs are
        updated to the new locations so a later ``save_manifest`` does not
        point at files that no longer exist.

        Args:
            train_ratio: Ratio of training data
            shuffle: Shuffle data before splitting

        Returns:
            Path to data.yaml
        """
        import random
        import yaml

        images = list(self.images_dir.glob(f"*.{self.image_format}"))
        if shuffle:
            random.shuffle(images)

        split_idx = int(len(images) * train_ratio)
        train_images = images[:split_idx]
        val_images = images[split_idx:]

        # Lookup so moved files can be reflected in the recorded pairs.
        pairs_by_image = {pair.image_path: pair for pair in self._pairs}

        for split_name, split_images in (("train", train_images), ("val", val_images)):
            img_dir = self.output_dir / "images" / split_name
            lbl_dir = self.output_dir / "labels" / split_name
            img_dir.mkdir(parents=True, exist_ok=True)
            lbl_dir.mkdir(parents=True, exist_ok=True)

            for img in split_images:
                pair = pairs_by_image.get(str(img))
                new_img = img_dir / img.name
                shutil.move(str(img), str(new_img))
                if pair is not None:
                    pair.image_path = str(new_img)

                lbl = self.labels_dir / f"{img.stem}.txt"
                if lbl.exists():
                    new_lbl = lbl_dir / lbl.name
                    shutil.move(str(lbl), str(new_lbl))
                    if pair is not None:
                        pair.label_path = str(new_lbl)

        data_config = {
            'path': str(self.output_dir.absolute()),
            'train': 'images/train',
            'val': 'images/val',
            'names': self.class_names,
            'nc': len(self.class_names),
        }
        yaml_path = self.output_dir / "data.yaml"
        with open(yaml_path, 'w') as f:
            yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)

        print(f"Dataset created:")
        print(f" Train: {len(train_images)} images")
        print(f" Val: {len(val_images)} images")
        print(f" Config: {yaml_path}")
        return str(yaml_path)

    def save_manifest(self) -> str:
        """
        Save manifest of all annotation pairs.

        Returns:
            Path to manifest file
        """
        manifest = {
            'created': datetime.now().isoformat(),
            'total_pairs': len(self._pairs),
            'output_dir': str(self.output_dir),
            'pairs': [
                {
                    'image': p.image_path,
                    'label': p.label_path,
                    'frame_id': p.frame_id,
                    'timestamp': p.timestamp,
                    'camera': p.camera_name,
                    'num_detections': len(p.detections),
                }
                for p in self._pairs
            ],
        }
        manifest_path = self.output_dir / "manifest.json"
        with open(manifest_path, 'w') as f:
            json.dump(manifest, f, indent=2)
        return str(manifest_path)
+257
View File
@@ -0,0 +1,257 @@
"""
Video source handler for local video files such as MP4 (the source must be an existing path).
"""
import cv2
import numpy as np
from pathlib import Path
from typing import Generator, Tuple, Optional, Dict, Any
from dataclasses import dataclass
from threading import Thread, Lock
from queue import Queue
import time
@dataclass
class VideoInfo:
    """Basic properties of an opened video.

    Attributes:
        path: Location of the video.
        fps: Frames per second.
        frame_count: Number of frames.
        width: Frame width.
        height: Frame height.
        duration: Length of the video.
        codec: Codec identifier; defaults to an empty string.
    """

    path: str
    fps: float
    frame_count: int
    width: int
    height: int
    duration: float
    codec: str = ""
class VideoSource:
    """Video source handler supporting MP4 files.

    Wraps ``cv2.VideoCapture`` with optional frame-rate limiting, resizing
    and looping. Usable as a context manager, which opens on entry and
    closes on exit.
    """

    def __init__(
        self,
        source: str,
        fps_limit: Optional[float] = None,
        loop: bool = False,
        resize: Optional[Tuple[int, int]] = None,
    ):
        """
        Initialize video source (the file is not opened yet).

        Args:
            source: Path to MP4 file
            fps_limit: Maximum FPS to process
            loop: Loop video when finished
            resize: Resize frames to (width, height)
        """
        self.source = source
        self.fps_limit = fps_limit
        self.loop = loop
        self.resize = resize
        self.cap: Optional[cv2.VideoCapture] = None
        self.info: Optional[VideoInfo] = None
        self._frame_interval = 0  # set to >= 1 by open()
        self._is_running = False

    def open(self) -> VideoInfo:
        """Open video source and return info.

        Raises:
            FileNotFoundError: If ``source`` does not exist on disk.
            ValueError: If OpenCV cannot open the file.
        """
        path = Path(self.source)
        if not path.exists():
            raise FileNotFoundError(f"Video not found: {self.source}")

        self.cap = cv2.VideoCapture(str(path))
        if not self.cap.isOpened():
            raise ValueError(f"Cannot open video: {self.source}")

        fps = self.cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Keep every N-th frame so iterate() yields about fps_limit frames
        # per second of video.
        if self.fps_limit and self.fps_limit < fps:
            self._frame_interval = int(fps / self.fps_limit)
        else:
            self._frame_interval = 1

        self.info = VideoInfo(
            path=str(path),
            fps=fps,
            frame_count=frame_count,
            width=width,
            height=height,
            duration=frame_count / fps if fps > 0 else 0,
        )
        self._is_running = True
        return self.info

    def close(self) -> None:
        """Close video source."""
        self._is_running = False
        if self.cap is not None:
            self.cap.release()
            self.cap = None

    def read(self) -> Tuple[bool, Optional[np.ndarray], int]:
        """
        Read next frame, opening the source first if necessary.

        When ``loop`` is set and the end is reached, the video is rewound
        to frame 0 and one more read is attempted.

        Returns:
            Tuple of (success, frame, frame_index)
        """
        if self.cap is None:
            self.open()

        ret, frame = self.cap.read()
        # POS_FRAMES points at the next frame, hence the -1.
        frame_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) - 1

        if not ret:
            if not self.loop:
                return False, None, frame_idx
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ret, frame = self.cap.read()
            frame_idx = 0

        if ret and self.resize:
            frame = cv2.resize(frame, self.resize)
        return ret, frame, frame_idx

    def iterate(
        self,
        start_time: float = 0,
        end_time: Optional[float] = None,
        max_frames: Optional[int] = None,
    ) -> Generator[Tuple[int, np.ndarray, float], None, None]:
        """
        Iterate through video frames.

        Only every ``_frame_interval``-th frame (counted from the start
        frame) is yielded. With ``loop`` set, reaching the end rewinds to
        the start frame; iteration stops instead if a rewind produces no
        frame at all, so an unreadable file cannot spin forever.

        Args:
            start_time: Start time in seconds
            end_time: End time in seconds
            max_frames: Maximum frames to yield

        Yields:
            Tuple of (frame_index, frame, timestamp)
        """
        if self.cap is None:
            self.open()

        fps = self.info.fps
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps) if end_time else self.info.frame_count
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frame_idx = start_frame
        yielded = 0
        read_since_rewind = False  # guards against endless rewinding
        while frame_idx < end_frame and self._is_running:
            if max_frames and yielded >= max_frames:
                break

            wanted = (frame_idx - start_frame) % self._frame_interval == 0
            if wanted:
                ok, frame = self.cap.read()
            else:
                # Skipped frames only need to be advanced past, not decoded.
                ok, frame = self.cap.grab(), None

            if not ok:
                if self.loop and read_since_rewind:
                    self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
                    frame_idx = start_frame
                    read_since_rewind = False
                    continue
                break
            read_since_rewind = True

            if wanted:
                if self.resize:
                    frame = cv2.resize(frame, self.resize)
                # Some containers report 0 FPS; avoid dividing by it.
                timestamp = frame_idx / fps if fps > 0 else 0.0
                yield frame_idx, frame, timestamp
                yielded += 1
            frame_idx += 1

    def __enter__(self):
        """Open the source and return ``self``."""
        self.open()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the source; exceptions are not suppressed."""
        self.close()
class BufferedVideoSource(VideoSource):
    """Video source with frame buffering for smoother processing.

    A daemon thread started by ``open()`` decodes frames into a bounded
    queue; ``read()`` pops from that queue.

    NOTE(review): the inherited ``iterate()`` reads from ``self.cap``
    directly and would compete with the reader thread for frames; use
    ``read()`` with this class. Confirm intended usage with callers.
    """

    def __init__(
        self,
        source: str,
        buffer_size: int = 10,
        **kwargs,
    ):
        """
        Initialize buffered video source.

        Args:
            source: Path to video
            buffer_size: Number of frames to buffer
            **kwargs: Additional VideoSource arguments
        """
        super().__init__(source, **kwargs)
        self.buffer_size = buffer_size
        self._buffer: Queue = Queue(maxsize=buffer_size)
        self._thread: Optional[Thread] = None
        self._lock = Lock()  # guards every use of self.cap across threads

    def _reader_thread(self) -> None:
        """Background thread to read frames into buffer.

        On end of stream it either rewinds (``loop``) or enqueues a
        ``(False, None, -1)`` sentinel and exits. A rewind that yields no
        frame also ends the thread instead of spinning on a dead source.
        """
        read_since_rewind = False
        while self._is_running:
            if self._buffer.full():
                time.sleep(0.001)
                continue

            with self._lock:
                if self.cap is None:
                    break
                ret, frame = self.cap.read()
                frame_idx = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) - 1
                rewind = not ret and self.loop and read_since_rewind
                if rewind:
                    self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

            if rewind:
                read_since_rewind = False
                continue
            if not ret:
                # This thread is the only producer and the queue was just
                # seen non-full, so this put() does not block.
                self._buffer.put((False, None, -1))
                break
            read_since_rewind = True

            if self.resize:
                frame = cv2.resize(frame, self.resize)
            self._buffer.put((True, frame, frame_idx))

    def open(self) -> VideoInfo:
        """Open video and start buffer thread."""
        info = super().open()
        self._thread = Thread(target=self._reader_thread, daemon=True)
        self._thread.start()
        return info

    def close(self) -> None:
        """Stop buffer thread and close video."""
        self._is_running = False
        if self._thread is not None:
            self._thread.join(timeout=1.0)
        # Release under the lock so a reader that outlived the join timeout
        # is never in the middle of cap.read() when the capture goes away.
        with self._lock:
            super().close()

    def read(self) -> Tuple[bool, Optional[np.ndarray], int]:
        """Read frame from buffer.

        Returns:
            ``(success, frame, frame_index)``; ``(False, None, -1)`` when the
            source is not running, the stream ended, or no frame arrived
            within one second.
        """
        from queue import Empty

        if not self._is_running:
            return False, None, -1
        try:
            return self._buffer.get(timeout=1.0)
        except Empty:
            # Only a buffer timeout is expected here; anything else
            # (including KeyboardInterrupt) propagates.
            return False, None, -1
+297
View File
@@ -0,0 +1,297 @@
"""
Debug visualization utilities.
"""
import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
class DebugVisualizer:
    """Overlay helpers for inspecting detections on frames.

    Every ``draw_*`` method works on a copy of the input frame and returns
    that copy.
    """

    def __init__(
        self,
        box_color: Tuple[int, int, int] = (0, 255, 0),
        box_thickness: int = 2,
        font_scale: float = 0.5,
        show_confidence: bool = True,
        show_class: bool = True,
        show_fps: bool = True,
    ):
        """
        Store drawing options and reset the FPS bookkeeping.

        Args:
            box_color: BGR color used for boxes without a per-class color
            box_thickness: Line thickness of the boxes
            font_scale: Font scale of the box labels
            show_confidence: Include the confidence score in labels
            show_class: Include the class name in labels
            show_fps: Stored flag for the FPS counter
        """
        self.box_color = box_color
        self.box_thickness = box_thickness
        self.font_scale = font_scale
        self.show_confidence = show_confidence
        self.show_class = show_class
        self.show_fps = show_fps
        self._fps_history: List[float] = []
        self._last_time = datetime.now()

    def draw_detections(
        self,
        frame: np.ndarray,
        detections: List[Any],
        colors: Optional[Dict[int, Tuple[int, int, int]]] = None,
    ) -> np.ndarray:
        """
        Return a copy of ``frame`` with one box (and label) per detection.

        Detections may be objects with ``bbox``/``class_id``/``class_name``/
        ``confidence`` attributes or dicts with the same keys.

        Args:
            frame: Input frame (BGR)
            detections: Detections to draw
            colors: Optional class_id to color mapping

        Returns:
            Annotated copy of the frame
        """
        canvas = frame.copy()
        for item in detections:
            if hasattr(item, 'bbox'):
                left, top = int(item.bbox.x1), int(item.bbox.y1)
                right, bottom = int(item.bbox.x2), int(item.bbox.y2)
                class_id = item.class_id
                class_name = item.class_name
                conf = item.confidence
            else:
                corners = item.get('bbox', [0, 0, 0, 0])
                left, top, right, bottom = [int(v) for v in corners]
                class_id = item.get('class_id', 0)
                class_name = item.get('class_name', str(class_id))
                conf = item.get('confidence', 0)

            # Per-class color wins over the default box color.
            color = colors[class_id] if colors and class_id in colors else self.box_color
            cv2.rectangle(canvas, (left, top), (right, bottom), color, self.box_thickness)

            pieces = []
            if self.show_class:
                pieces.append(class_name)
            if self.show_confidence:
                pieces.append(f"{conf:.2f}")
            if not pieces:
                continue

            text = " ".join(pieces)
            (text_w, text_h), baseline = cv2.getTextSize(
                text, cv2.FONT_HERSHEY_SIMPLEX, self.font_scale, 1
            )
            # Filled strip directly above the box, then black text on it.
            cv2.rectangle(
                canvas,
                (left, top - text_h - 4),
                (left + text_w + 4, top),
                color,
                -1,
            )
            cv2.putText(
                canvas,
                text,
                (left + 2, top - 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                self.font_scale,
                (0, 0, 0),
                1,
            )
        return canvas

    def _measure_fps(self) -> float:
        """Update the rolling FPS history from the wall clock and return its mean.

        Returns 0 when no time has elapsed since the previous call.
        """
        current = datetime.now()
        elapsed = (current - self._last_time).total_seconds()
        self._last_time = current
        if not elapsed > 0:
            return 0
        self._fps_history.append(1.0 / elapsed)
        # Keep at most the 30 most recent samples.
        if len(self._fps_history) > 30:
            del self._fps_history[0]
        return sum(self._fps_history) / len(self._fps_history)

    def draw_fps(self, frame: np.ndarray, fps: Optional[float] = None) -> np.ndarray:
        """
        Return a copy of ``frame`` with an FPS readout in the top-left corner.

        Args:
            frame: Input frame
            fps: Value to display; measured from call timing when None

        Returns:
            Annotated copy of the frame
        """
        if fps is None:
            fps = self._measure_fps()

        canvas = frame.copy()
        cv2.putText(
            canvas,
            f"FPS: {fps:.1f}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.0,
            (0, 255, 0),
            2,
        )
        return canvas

    def draw_stats(
        self,
        frame: np.ndarray,
        stats: Dict[str, Any],
        position: Tuple[int, int] = (10, 60),
    ) -> np.ndarray:
        """
        Return a copy of ``frame`` with one ``key: value`` line per stat.

        Args:
            frame: Input frame
            stats: Values to display; floats are shown with two decimals
            position: Position (x, y) of the first line

        Returns:
            Annotated copy of the frame
        """
        canvas = frame.copy()
        x, y = position
        step = 25
        for name, val in stats.items():
            line = f"{name}: {val:.2f}" if isinstance(val, float) else f"{name}: {val}"
            # Thick white pass first, thin black pass on top.
            for ink, weight in (((255, 255, 255), 2), ((0, 0, 0), 1)):
                cv2.putText(
                    canvas,
                    line,
                    (x, y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    ink,
                    weight,
                )
            y += step
        return canvas
class ObjectListDisplay:
    """Render detections as plain text for console or log output."""

    def __init__(
        self,
        show_confidence: bool = True,
        show_class: bool = True,
        show_bbox: bool = True,
        show_track_id: bool = False,
    ):
        """
        Choose which fields appear in the formatted text.

        Args:
            show_confidence: Include ``conf=<score>``
            show_class: Include the class name
            show_bbox: Include ``bbox=[x1,y1,x2,y2]``
            show_track_id: Include ``id=<track id>`` when one is present
        """
        self.show_confidence = show_confidence
        self.show_class = show_class
        self.show_bbox = show_bbox
        self.show_track_id = show_track_id

    def format_detection(self, detection: Any) -> str:
        """
        Build the one-line description of a detection.

        Args:
            detection: Object with ``bbox``/``class_name``/``confidence``
                attributes, or a dict with the same keys

        Returns:
            Enabled fields joined by ``" | "``
        """
        if hasattr(detection, 'bbox'):
            class_name = detection.class_name
            conf = detection.confidence
            box = detection.bbox
            track_id = getattr(detection, 'track_id', None)
            x1, y1, x2, y2 = box.x1, box.y1, box.x2, box.y2
        else:
            # Dict input: missing keys fall back to neutral defaults.
            class_name = detection.get('class_name', str(detection.get('class_id', 0)))
            conf = detection.get('confidence', 0)
            box = detection.get('bbox', [0, 0, 0, 0])
            track_id = detection.get('track_id')
            x1, y1, x2, y2 = box

        fields = []
        if self.show_class:
            fields.append(f"{class_name}")
        if self.show_confidence:
            fields.append(f"conf={conf:.2f}")
        if self.show_bbox:
            fields.append(f"bbox=[{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}]")
        if self.show_track_id and track_id is not None:
            fields.append(f"id={track_id}")
        return " | ".join(fields)

    def format_list(self, detections: List[Any], frame_id: int = 0) -> str:
        """
        Build a header line followed by one indexed line per detection.

        Args:
            detections: Detections to describe
            frame_id: Frame index shown in the header

        Returns:
            Newline-joined text block
        """
        header = f"Frame {frame_id}: {len(detections)} objects"
        body = [
            f" [{index}] {self.format_detection(item)}"
            for index, item in enumerate(detections)
        ]
        return "\n".join([header] + body)

    def print_list(self, detections: List[Any], frame_id: int = 0) -> None:
        """Print the text produced by ``format_list`` to standard output."""
        print(self.format_list(detections, frame_id))