124 lines
3.6 KiB
Python
124 lines
3.6 KiB
Python
"""
|
|
Ultralytics YOLO detector backend.
|
|
"""
|
|
|
|
import numpy as np
|
|
import logging
|
|
from typing import List, Optional
|
|
|
|
from .base import BaseDetector, Detection, BBox
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class YOLODetector(BaseDetector):
    """
    Ultralytics YOLO detector.

    Supports YOLOv5, YOLOv8, YOLOv9, etc.

    Typical lifecycle: construct, call ``load_model()`` once, call
    ``detect()`` per frame, then ``release()`` when done.
    """

    def __init__(
        self,
        model_path: str,
        conf_threshold: float = 0.25,
        nms_threshold: float = 0.45,
        device: str = "cpu",
        class_names: Optional[dict] = None,
    ):
        """
        Initialize YOLO detector.

        Args:
            model_path: Path to .pt model file
            conf_threshold: Confidence threshold
            nms_threshold: NMS IoU threshold
            device: Device to run on ('cpu', 'cuda', '0', etc.)
            class_names: Class ID to name mapping
        """
        # NOTE(review): this class relies on BaseDetector.__init__ to store
        # model_path / conf_threshold / nms_threshold / class_names and to
        # initialize self.model -- confirm against the base class.
        super().__init__(
            model_path=model_path,
            conf_threshold=conf_threshold,
            nms_threshold=nms_threshold,
            class_names=class_names,
        )
        self.device = device

    def load_model(self) -> bool:
        """
        Load the YOLO model and move it to the configured device.

        ``self.model`` is only assigned once the model has been created AND
        moved to ``self.device`` successfully, so a failed load never leaves
        a half-initialized model behind for ``detect()`` to use.

        Returns:
            True if the model was loaded, False otherwise (the error is
            logged, never raised).
        """
        # Imported lazily so the module can be imported without ultralytics.
        try:
            from ultralytics import YOLO
        except ImportError:
            logger.error("ultralytics package not found. Install with: pip install ultralytics")
            return False

        try:
            logger.info("Loading YOLO model: %s", self.model_path)
            model = YOLO(self.model_path)
            model.to(self.device)
            model_names = getattr(model, "names", None)
        except Exception as e:
            # Broad on purpose: loading is a boundary where any failure
            # (bad path, corrupt weights, unavailable device) maps to False.
            logger.exception("Failed to load YOLO model: %s", e)
            return False

        self.model = model

        # Update class names from model if available; an empty/None value
        # must not clobber a mapping supplied by the caller.
        if model_names:
            self.class_names = model_names

        logger.info("YOLO model loaded on %s", self.device)
        return True

    def detect(self, frame: np.ndarray) -> List[Detection]:
        """
        Run detection on frame.

        Args:
            frame: Input image (BGR, HWC)

        Returns:
            List of Detection objects. Empty if the model is not loaded or
            if inference fails (the error is logged, never raised).
        """
        if self.model is None:
            logger.warning("Model not loaded")
            return []

        # class_names may legitimately be None (constructor default); fall
        # back to an empty mapping so names degrade to the stringified ID.
        names = self.class_names or {}

        try:
            # Run inference
            results = self.model.predict(
                frame,
                conf=self.conf_threshold,
                iou=self.nms_threshold,
                verbose=False,
            )

            detections = []
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    continue

                # Convert each field once per result instead of once per box.
                class_ids = boxes.cls.tolist()
                confidences = boxes.conf.tolist()
                corners = boxes.xyxy.tolist()

                for raw_id, conf, (x1, y1, x2, y2) in zip(class_ids, confidences, corners):
                    class_id = int(raw_id)
                    detections.append(
                        Detection(
                            class_id=class_id,
                            class_name=names.get(class_id, str(class_id)),
                            confidence=float(conf),
                            bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2),
                        )
                    )

            return detections

        except Exception as e:
            # Best-effort per-frame contract: log with traceback and return
            # no detections rather than propagating to the caller.
            logger.exception("Detection error: %s", e)
            return []

    def release(self) -> None:
        """Release resources by dropping the reference to the loaded model."""
        self.model = None
        logger.info("YOLO detector released")
|