231 lines
6.8 KiB
Python
231 lines
6.8 KiB
Python
"""
|
|
Base detector interface.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Tuple, Optional
|
|
from dataclasses import dataclass
|
|
import numpy as np
|
|
|
|
|
|
@dataclass
class BBox:
    """Axis-aligned box described by two opposite corners.

    ``(x1, y1)`` and ``(x2, y2)`` are the corner coordinates. Extents are
    computed as ``x2 - x1`` and ``y2 - y1`` with no reordering or clamping,
    so they are non-negative only when ``x2 >= x1`` and ``y2 >= y1``.
    """

    x1: float
    y1: float
    x2: float
    y2: float

    @property
    def width(self) -> float:
        """Horizontal extent, ``x2 - x1``."""
        return self.x2 - self.x1

    @property
    def height(self) -> float:
        """Vertical extent, ``y2 - y1``."""
        return self.y2 - self.y1

    def area(self) -> float:
        """Return the box area (width times height) in pixels."""
        return self.width * self.height

    def to_yolo(self, img_width: int, img_height: int) -> Tuple[float, float, float, float]:
        """Express the box in YOLO format.

        Args:
            img_width: Width of the image the box belongs to.
            img_height: Height of the image the box belongs to.

        Returns:
            ``(x_center, y_center, width, height)``, each divided by the
            matching image dimension.
        """
        mid_x = (self.x1 + self.x2) / 2
        mid_y = (self.y1 + self.y2) / 2
        return (
            mid_x / img_width,
            mid_y / img_height,
            self.width / img_width,
            self.height / img_height,
        )
|
|
|
|
|
|
@dataclass
class Detection:
    """One detected object as reported by a detector."""

    # Numeric class identifier.
    class_id: int
    # Human-readable label for ``class_id``.
    class_name: str
    # Detection score (presumably in [0, 1]; not enforced here).
    confidence: float
    # Location of the object as a ``BBox``.
    bbox: BBox
    # Optional track identifier; stays ``None`` unless one is assigned.
    track_id: Optional[int] = None
|
|
|
|
|
|
class BaseDetector(ABC):
    """Abstract base class for object detectors.

    Subclasses must implement :meth:`load_model`, :meth:`detect` and
    :meth:`release`. The base class stores the shared configuration and
    supplies a default :meth:`preprocess`, a placeholder :meth:`postprocess`
    and a NumPy implementation of non-maximum suppression (:meth:`nms`).
    """

    # Default class-ID -> name mapping (COCO, IDs 0-79). Used whenever the
    # caller does not pass ``class_names`` to ``__init__``.
    COCO_CLASSES = {
        0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
        5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
        10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
        14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
        20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
        25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
        30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite',
        34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard',
        38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork',
        43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple',
        48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog',
        53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch',
        58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv',
        63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone',
        68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator',
        73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear',
        78: 'hair drier', 79: 'toothbrush'
    }

    def __init__(
        self,
        model_path: str,
        input_size: Tuple[int, int] = (640, 640),
        conf_threshold: float = 0.25,
        nms_threshold: float = 0.45,
        class_names: Optional[dict] = None,
    ):
        """
        Initialize detector.

        Only stores the configuration; the model itself is not loaded here
        (``self.model`` starts as ``None``).

        Args:
            model_path: Path to model file
            input_size: Model input size (width, height)
            conf_threshold: Confidence threshold
            nms_threshold: NMS IoU threshold
            class_names: Class ID to name mapping. When omitted (or empty)
                the shared ``COCO_CLASSES`` mapping is used.
        """
        self.model_path = model_path
        self.input_size = input_size
        self.conf_threshold = conf_threshold
        self.nms_threshold = nms_threshold
        self.class_names = class_names or self.COCO_CLASSES
        self.model = None

    @abstractmethod
    def load_model(self) -> bool:
        """Load model. Returns True on success."""
        pass

    @abstractmethod
    def detect(self, frame: np.ndarray) -> List["Detection"]:
        """
        Run detection on frame.

        Args:
            frame: Input image (BGR, HWC)

        Returns:
            List of Detection objects
        """
        pass

    @abstractmethod
    def release(self) -> None:
        """Release resources."""
        pass

    def preprocess(self, frame: np.ndarray) -> np.ndarray:
        """
        Preprocess frame for inference.

        The frame is resized to ``self.input_size``, converted from BGR to
        RGB, scaled by 1/255 as float32, reordered from HWC to CHW and given
        a leading batch axis.

        Args:
            frame: Input frame (BGR, HWC)

        Returns:
            float32 array of shape (1, channels, input_height, input_width).
            Values fall in [0, 1] assuming an 8-bit input frame.
        """
        # Imported lazily so that importing this module does not require OpenCV.
        import cv2

        # cv2.resize takes the target size as (width, height).
        input_width, input_height = self.input_size
        resized = cv2.resize(frame, (input_width, input_height))

        # BGR to RGB
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

        # Normalize to [0, 1]
        normalized = rgb.astype(np.float32) / 255.0

        # HWC to CHW
        transposed = normalized.transpose(2, 0, 1)

        # Add batch dimension
        return np.expand_dims(transposed, axis=0)

    def postprocess(
        self,
        outputs: np.ndarray,
        original_shape: Tuple[int, int],
    ) -> List["Detection"]:
        """
        Postprocess model outputs.

        This base implementation is a placeholder: it does not inspect
        ``outputs`` and always returns an empty list. The output layout
        varies by model (e.g. ``[1, num_boxes, 5 + num_classes]`` for
        YOLO-style heads), so subclasses should override this method.

        Args:
            outputs: Raw model outputs
            original_shape: Original frame shape (height, width)

        Returns:
            List of Detection objects (always empty here)
        """
        return []

    def nms(
        self,
        boxes: np.ndarray,
        scores: np.ndarray,
        iou_threshold: float = 0.45,
    ) -> List[int]:
        """
        Non-maximum suppression.

        Boxes are visited in descending score order; each kept box removes
        every remaining box whose IoU with it exceeds ``iou_threshold``.
        The default threshold is the literal 0.45; ``self.nms_threshold``
        is not consulted, so pass it explicitly if it should apply.

        Args:
            boxes: Array of boxes [N, 4] in xyxy format
            scores: Array of scores [N]
            iou_threshold: IoU threshold

        Returns:
            List of indices to keep (plain Python ints), highest score first
        """
        boxes = np.asarray(boxes, dtype=np.float64)
        scores = np.asarray(scores)
        if boxes.size == 0:
            return []

        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort()[::-1]

        keep: List[int] = []
        while order.size > 0:
            best = order[0]
            # Convert so callers get the advertised List[int], not numpy.int64.
            keep.append(int(best))

            rest = order[1:]
            if rest.size == 0:
                break

            # Intersection rectangle of the best box with every remaining box.
            overlap_w = np.maximum(
                0.0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest])
            )
            overlap_h = np.maximum(
                0.0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest])
            )
            inter = overlap_w * overlap_h
            union = areas[best] + areas[rest] - inter

            # Zero-area boxes give a zero union; treat their IoU as 0 rather
            # than dividing 0/0, which yields NaN and silently drops the box.
            iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

            order = rest[iou <= iou_threshold]

        return keep
|