Files
2026-02-04 15:29:36 +07:00

281 lines
8.9 KiB
Python

"""
Annotation export utilities for YOLO format output.
"""
import cv2
import json
import shutil
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
import numpy as np
@dataclass
class AnnotationPair:
    """Snapshot and label pair.

    Record of one exported image file and its matching label file, as
    created by ``AnnotationExporter.save_pair``. Field order matters: it
    defines the positional order of the generated ``__init__``.
    """
    image_path: str  # path of the saved image file, stored as a string
    label_path: str  # path of the label .txt file written for that image
    frame_id: int  # frame index supplied by the caller
    timestamp: float  # timestamp supplied by the caller (units not fixed here — TODO confirm)
    detections: List[Any]  # detections exactly as passed in (objects with .bbox, or dicts)
    camera_name: str = "default"  # camera/source name supplied by the caller
class AnnotationExporter:
    """Export annotations in YOLO format with snapshot pairs.

    Layout under ``output_dir``::

        images/   saved frames, named <camera>_<YYYYmmdd_HHMMSS>_<frame_id>.<fmt>
        labels/   one YOLO ``.txt`` file per image, same file stem
        debug/    optional visualizations (created on demand)

    ``create_dataset_yaml`` later moves the pairs into ``train``/``val``
    sub-folders and writes ``data.yaml``. ``save_manifest`` writes a JSON
    index of every pair saved through this instance.
    """

    def __init__(
        self,
        output_dir: str,
        class_names: Optional[Dict[int, str]] = None,
        image_format: str = "jpg",
        image_quality: int = 95,
    ):
        """
        Initialize exporter and create the ``images``/``labels`` directories.

        Args:
            output_dir: Root output directory (created if missing).
            class_names: Mapping of class_id to name, written to data.yaml.
            image_format: Output image extension (e.g. "jpg", "png"). A
                leading dot is tolerated: ".png" is treated as "png".
            image_quality: JPEG quality (0-100). Only used for jpg/jpeg output.
        """
        self.output_dir = Path(output_dir)
        self.class_names = class_names or {}
        # Strip a leading dot so ".png" doesn't yield "name..png" files.
        self.image_format = image_format.lstrip(".")
        self.image_quality = image_quality
        # Create directories
        self.images_dir = self.output_dir / "images"
        self.labels_dir = self.output_dir / "labels"
        self.debug_dir = self.output_dir / "debug"
        self.images_dir.mkdir(parents=True, exist_ok=True)
        self.labels_dir.mkdir(parents=True, exist_ok=True)
        self._pairs: List["AnnotationPair"] = []

    def save_pair(
        self,
        frame: np.ndarray,
        detections: List[Any],
        frame_id: int,
        timestamp: float = 0.0,
        camera_name: str = "default",
        save_debug: bool = False,
    ) -> Optional["AnnotationPair"]:
        """
        Save snapshot and label pair.

        Args:
            frame: Image frame (BGR), written with ``cv2.imwrite``.
            detections: Detection objects (with ``bbox``/``class_id``) or
                dicts whose ``'bbox'`` is pixel ``[x1, y1, x2, y2]``.
            frame_id: Frame index, zero-padded into the file name.
            timestamp: Frame timestamp, stored on the returned pair.
            camera_name: Camera/source name, used as the file-name prefix.
                Path separators are replaced with "_" in the file name only.
            save_debug: Also write an annotated copy to ``debug/``.

        Returns:
            The recorded AnnotationPair. Returns None if there are no
            detections, or if the image could not be written. In the latter
            case a RuntimeWarning is issued and no label file is created.
        """
        # len() rather than truthiness so numpy arrays of detections work too.
        if detections is None or len(detections) == 0:
            return None
        h, w = frame.shape[:2]
        # The file name uses wall-clock export time, not the frame timestamp.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        # A "/" or "\" in the camera name would point into a non-existent
        # sub-directory and make every write below fail.
        safe_camera = str(camera_name).replace("/", "_").replace("\\", "_")
        base_name = f"{safe_camera}_{ts}_{frame_id:06d}"

        image_path = self.images_dir / f"{base_name}.{self.image_format}"
        if not self._write_image(image_path, frame):
            # Don't leave an orphan label or record a pair without an image.
            import warnings
            warnings.warn(
                f"could not write image {image_path}; frame {frame_id} skipped",
                RuntimeWarning,
                stacklevel=2,
            )
            return None

        label_path = self.labels_dir / f"{base_name}.txt"
        self._write_yolo_label(label_path, detections, w, h)

        if save_debug:
            self.debug_dir.mkdir(parents=True, exist_ok=True)
            debug_path = self.debug_dir / f"{base_name}_debug.{self.image_format}"
            # Best-effort: a failed debug write does not invalidate the pair.
            self._write_image(debug_path, self._draw_detections(frame.copy(), detections))

        pair = AnnotationPair(
            image_path=str(image_path),
            label_path=str(label_path),
            frame_id=frame_id,
            timestamp=timestamp,
            detections=detections,
            camera_name=camera_name,
        )
        self._pairs.append(pair)
        return pair

    def _write_image(self, path: Path, image: np.ndarray) -> bool:
        """Write *image* to *path* and return OpenCV's success flag.

        JPEG quality is applied for "jpg"/"jpeg" in any letter case.
        """
        if self.image_format.lower() in ("jpg", "jpeg"):
            params = [cv2.IMWRITE_JPEG_QUALITY, int(self.image_quality)]
            return bool(cv2.imwrite(str(path), image, params))
        return bool(cv2.imwrite(str(path), image))

    @staticmethod
    def _clip_yolo(
        x_c: float, y_c: float, w: float, h: float
    ) -> Tuple[float, float, float, float]:
        """Clip a normalized center/size box to the unit square.

        Boxes already inside [0, 1] are returned unchanged, so in-frame
        detections are written exactly as computed.
        """
        x1, x2 = x_c - w / 2, x_c + w / 2
        y1, y2 = y_c - h / 2, y_c + h / 2
        if x1 >= 0.0 and y1 >= 0.0 and x2 <= 1.0 and y2 <= 1.0:
            return x_c, y_c, w, h
        x1, x2 = min(max(x1, 0.0), 1.0), min(max(x2, 0.0), 1.0)
        y1, y2 = min(max(y1, 0.0), 1.0), min(max(y2, 0.0), 1.0)
        return (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1

    def _write_yolo_label(
        self,
        path: Path,
        detections: List[Any],
        img_width: int,
        img_height: int,
    ) -> None:
        """Write a YOLO label file with one ``class x_c y_c w h`` line per detection.

        Coordinates are normalized by the image size. They are clipped so
        each box lies inside the frame, because YOLO label coordinates are
        expected to be in [0, 1].
        """
        lines = []
        for det in detections:
            if hasattr(det, 'bbox'):
                # Object form: the bbox object normalizes itself.
                x_c, y_c, w, h = det.bbox.to_yolo(img_width, img_height)
                class_id = det.class_id
            else:
                # Dict form: 'bbox' is pixel [x1, y1, x2, y2].
                bbox = det.get('bbox', [0, 0, 0, 0])
                class_id = det.get('class_id', 0)
                x1, y1, x2, y2 = bbox
                x_c = ((x1 + x2) / 2) / img_width
                y_c = ((y1 + y2) / 2) / img_height
                w = (x2 - x1) / img_width
                h = (y2 - y1) / img_height
            x_c, y_c, w, h = self._clip_yolo(x_c, y_c, w, h)
            lines.append(f"{class_id} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}")
        with open(path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(lines))

    def _draw_detections(
        self,
        frame: np.ndarray,
        detections: List[Any],
    ) -> np.ndarray:
        """Draw a green box and a "<class> <conf>" tag for each detection.

        Draws in place on *frame* and returns it.
        """
        color = (0, 255, 0)
        for det in detections:
            if hasattr(det, 'bbox'):
                x1, y1 = int(det.bbox.x1), int(det.bbox.y1)
                x2, y2 = int(det.bbox.x2), int(det.bbox.y2)
                class_name = det.class_name
                conf = det.confidence
            else:
                bbox = det.get('bbox', [0, 0, 0, 0])
                x1, y1, x2, y2 = [int(x) for x in bbox]
                class_name = det.get('class_name', str(det.get('class_id', 0)))
                conf = det.get('confidence', 0)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            label = f"{class_name} {conf:.2f}"
            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            # The tag normally sits above the box. When the box touches the
            # top edge, draw it just inside the box so it stays visible.
            tag_top = y1 - label_h - 4
            if tag_top < 0:
                tag_top = y1
            cv2.rectangle(frame, (x1, tag_top), (x1 + label_w, tag_top + label_h + 4), color, -1)
            cv2.putText(frame, label, (x1, tag_top + label_h + 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
        return frame

    def _move_split(
        self,
        images: List[Path],
        img_dir: Path,
        lbl_dir: Path,
    ) -> Dict[str, Tuple[str, Optional[str]]]:
        """Move images (and their labels, if present) into one split's folders.

        Returns a map from each old image path string to the new (image, label)
        path strings. The label is None if the image had no label file.
        """
        moved: Dict[str, Tuple[str, Optional[str]]] = {}
        for img in images:
            lbl = self.labels_dir / f"{img.stem}.txt"
            new_img = img_dir / img.name
            shutil.move(str(img), str(new_img))
            new_lbl: Optional[str] = None
            if lbl.exists():
                new_lbl = str(lbl_dir / lbl.name)
                shutil.move(str(lbl), new_lbl)
            moved[str(img)] = (str(new_img), new_lbl)
        return moved

    def create_dataset_yaml(
        self,
        train_ratio: float = 0.8,
        shuffle: bool = True,
        seed: Optional[int] = None,
    ) -> str:
        """
        Create YOLO data.yaml and split dataset.

        Every image at the top level of ``images/`` is moved to
        ``images/train`` or ``images/val``. Its label, if present, moves to
        the matching ``labels/`` sub-folder. Pairs recorded by ``save_pair``
        are updated to the new paths, so a later ``save_manifest`` stays
        accurate.

        Args:
            train_ratio: Fraction of images assigned to train, in [0, 1].
            shuffle: Shuffle before splitting. Otherwise split in file-name order.
            seed: Optional seed for a reproducible shuffle. When None, the
                global ``random`` state is used, as before.

        Returns:
            Path to data.yaml.

        Raises:
            ValueError: If ``train_ratio`` is outside [0, 1].
        """
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")
        import random
        import yaml

        # Sorted first so the split does not depend on filesystem listing order.
        images = sorted(self.images_dir.glob(f"*.{self.image_format}"))
        if shuffle:
            if seed is None:
                random.shuffle(images)
            else:
                random.Random(seed).shuffle(images)
        split_idx = int(len(images) * train_ratio)
        train_images = images[:split_idx]
        val_images = images[split_idx:]

        train_img_dir = self.output_dir / "images" / "train"
        val_img_dir = self.output_dir / "images" / "val"
        train_lbl_dir = self.output_dir / "labels" / "train"
        val_lbl_dir = self.output_dir / "labels" / "val"
        for d in [train_img_dir, val_img_dir, train_lbl_dir, val_lbl_dir]:
            d.mkdir(parents=True, exist_ok=True)

        moved = self._move_split(train_images, train_img_dir, train_lbl_dir)
        moved.update(self._move_split(val_images, val_img_dir, val_lbl_dir))
        # Keep recorded pairs pointing at the files' new locations.
        for pair in self._pairs:
            new_paths = moved.get(pair.image_path)
            if new_paths is not None:
                pair.image_path = new_paths[0]
                if new_paths[1] is not None:
                    pair.label_path = new_paths[1]

        data_config = {
            'path': str(self.output_dir.absolute()),
            'train': 'images/train',
            'val': 'images/val',
            'names': self.class_names,
            'nc': len(self.class_names),
        }
        yaml_path = self.output_dir / "data.yaml"
        with open(yaml_path, 'w', encoding='utf-8') as f:
            yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)
        print("Dataset created:")
        print(f" Train: {len(train_images)} images")
        print(f" Val: {len(val_images)} images")
        print(f" Config: {yaml_path}")
        return str(yaml_path)

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """json.dump fallback that unwraps numpy scalars (e.g. np.int64) via ``.item()``."""
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_manifest(self) -> str:
        """
        Save ``manifest.json`` listing every pair saved by this instance.

        Returns:
            Path to manifest file.
        """
        manifest = {
            'created': datetime.now().isoformat(),
            'total_pairs': len(self._pairs),
            'output_dir': str(self.output_dir),
            'pairs': [
                {
                    'image': p.image_path,
                    'label': p.label_path,
                    'frame_id': p.frame_id,
                    'timestamp': p.timestamp,
                    'camera': p.camera_name,
                    'num_detections': len(p.detections),
                }
                for p in self._pairs
            ],
        }
        manifest_path = self.output_dir / "manifest.json"
        with open(manifest_path, 'w', encoding='utf-8') as f:
            # Callers often pass numpy scalars for frame_id/timestamp.
            json.dump(manifest, f, indent=2, default=self._json_default)
        return str(manifest_path)