"""
ONNX Runtime detector backend.
"""
|
|
|
|
import numpy as np
|
|
import logging
|
|
from typing import List, Tuple, Optional
|
|
|
|
from .base import BaseDetector, Detection, BBox
|
|
|
|
# Module-level logger named after this module; used by the detector methods below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ONNXDetector(BaseDetector):
    """
    ONNX Runtime-based YOLO detector.

    Supports CPU and CUDA execution providers.
    This is the recommended backend for CPU-only inference.

    Features:
    - Cross-platform (Linux, Windows, macOS, ARM)
    - No special hardware required
    - Optimized CPU inference with threading
    - Optional CUDA support
    """

    def __init__(
        self,
        model_path: str,
        input_size: Tuple[int, int] = (640, 640),
        conf_threshold: float = 0.25,
        nms_threshold: float = 0.45,
        device: str = "cpu",
        num_threads: int = 0,
        optimization_level: str = "all",
        class_names: Optional[dict] = None,
    ):
        """
        Initialize ONNX detector.

        Args:
            model_path: Path to .onnx model file
            input_size: Model input size (width, height)
            conf_threshold: Confidence threshold
            nms_threshold: NMS IoU threshold
            device: Device ('cpu' or 'cuda')
            num_threads: CPU threads (0 = auto based on CPU cores)
            optimization_level: Graph optimization ('none', 'basic', 'extended', 'all')
            class_names: Class ID to name mapping
        """
        super().__init__(
            model_path=model_path,
            input_size=input_size,
            conf_threshold=conf_threshold,
            nms_threshold=nms_threshold,
            class_names=class_names,
        )
        self.device = device
        self.num_threads = num_threads
        self.optimization_level = optimization_level
        # Populated by load_model(); detect() refuses to run while session is None.
        self.session = None
        self.input_name = None
        self.output_names = None

    def load_model(self) -> bool:
        """
        Load the ONNX model and create the inference session.

        The session and its metadata are committed to the instance only after
        every step succeeded, so a failed load never leaves the detector in a
        half-initialized state.

        Returns:
            True if the model was loaded, False otherwise (the reason is logged).
        """
        # Keep the ImportError handler scoped to the import itself so that
        # unrelated import failures are not reported as "onnxruntime not found".
        try:
            import onnxruntime as ort
        except ImportError:
            logger.error("onnxruntime not found. Install with: pip install onnxruntime")
            return False

        try:
            # Select execution providers (CPU is kept as a fallback for CUDA)
            if self.device == "cuda":
                providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            else:
                providers = ['CPUExecutionProvider']

            logger.info("Loading ONNX model: %s", self.model_path)
            logger.info("  Device: %s", self.device)
            logger.info("  Threads: %s", self.num_threads if self.num_threads > 0 else 'auto')

            sess_options = ort.SessionOptions()

            # Set optimization level; unknown names fall back to 'all'
            opt_levels = {
                'none': ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
                'basic': ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
                'extended': ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
                'all': ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
            }
            if self.optimization_level not in opt_levels:
                logger.warning(
                    "Unknown optimization level %r, using 'all'",
                    self.optimization_level,
                )
            sess_options.graph_optimization_level = opt_levels.get(
                self.optimization_level,
                ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
            )

            # Set CPU threading options (left at runtime defaults when 0)
            if self.num_threads > 0:
                sess_options.intra_op_num_threads = self.num_threads
                sess_options.inter_op_num_threads = self.num_threads

            # Enable memory optimization
            sess_options.enable_mem_pattern = True
            sess_options.enable_cpu_mem_arena = True

            session = ort.InferenceSession(
                self.model_path,
                sess_options=sess_options,
                providers=providers,
            )

            # Get input/output info
            model_input = session.get_inputs()[0]
            input_name = model_input.name
            output_names = [o.name for o in session.get_outputs()]

            # Adopt the model's own spatial size, but only when it is static:
            # dynamic axes are reported as strings/None and must not replace
            # the configured (width, height).
            input_size = self.input_size
            input_shape = model_input.shape
            if len(input_shape) == 4:
                height, width = input_shape[2], input_shape[3]
                is_static = all(
                    isinstance(dim, (int, np.integer)) and dim > 0
                    for dim in (height, width)
                )
                if is_static:
                    input_size = (int(width), int(height))
                else:
                    logger.info(
                        "  Dynamic input shape %s, keeping configured input size",
                        input_shape,
                    )

            actual_provider = session.get_providers()[0]
            if self.device == "cuda" and actual_provider != 'CUDAExecutionProvider':
                logger.warning(
                    "CUDA was requested but is not available, running on %s",
                    actual_provider,
                )
        except Exception as e:
            logger.error("Failed to load ONNX model: %s", e)
            return False

        # Commit only after everything above succeeded.
        self.session = session
        self.input_name = input_name
        self.output_names = output_names
        self.input_size = input_size

        logger.info("ONNX model loaded successfully")
        logger.info("  Provider: %s", actual_provider)
        logger.info("  Input size: %s", self.input_size)
        return True

    def detect(self, frame: np.ndarray) -> List[Detection]:
        """
        Run detection on frame.

        Args:
            frame: Input image (BGR, HWC)

        Returns:
            List of Detection objects. Empty when the session is not loaded,
            the frame is empty, or inference fails (the error is logged).
        """
        if self.session is None:
            logger.warning("ONNX session not initialized")
            return []

        try:
            if frame is None or frame.size == 0:
                logger.warning("Empty frame passed to ONNX detector")
                return []

            orig_h, orig_w = frame.shape[:2]

            # Preprocess
            input_tensor, ratio, pad = self._preprocess(frame)

            # Run inference
            outputs = self.session.run(self.output_names, {self.input_name: input_tensor})

            # Postprocess
            return self._postprocess(outputs, (orig_h, orig_w), ratio, pad)

        except Exception as e:
            # Best-effort by design: a bad frame must not stop the caller's loop.
            logger.error("ONNX inference error: %s", e)
            return []

    def _preprocess(self, frame: np.ndarray) -> Tuple[np.ndarray, float, Tuple[float, float]]:
        """
        Letterbox a BGR frame into the model's input tensor.

        Args:
            frame: Input image (BGR, HWC)

        Returns:
            Tuple of (tensor, ratio, (pad_left, pad_top)) where tensor is a
            float32 NCHW array scaled to [0, 1], ratio is the resize factor and
            the pad values are the border actually added on the left/top.
        """
        import cv2

        input_w, input_h = self.input_size
        orig_h, orig_w = frame.shape[:2]

        # Calculate scale; round instead of truncating so float error cannot
        # drop a pixel, and clamp so extreme aspect ratios never resize to 0.
        ratio = min(input_w / orig_w, input_h / orig_h)
        new_w = min(input_w, max(1, int(round(orig_w * ratio))))
        new_h = min(input_h, max(1, int(round(orig_h * ratio))))

        # Resize (skipped when the frame already has the target size)
        if (new_w, new_h) != (orig_w, orig_h):
            resized = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        else:
            resized = frame

        # Pad: split the remaining space; the -/+0.1 puts the odd pixel on the
        # bottom/right side.
        pad_w = (input_w - new_w) / 2
        pad_h = (input_h - new_h) / 2

        top = int(round(pad_h - 0.1))
        bottom = int(round(pad_h + 0.1))
        left = int(round(pad_w - 0.1))
        right = int(round(pad_w + 0.1))

        padded = cv2.copyMakeBorder(
            resized, top, bottom, left, right,
            cv2.BORDER_CONSTANT, value=(114, 114, 114)
        )

        # BGR to RGB
        rgb = cv2.cvtColor(padded, cv2.COLOR_BGR2RGB)

        # Normalize
        normalized = rgb.astype(np.float32) / 255.0

        # HWC to NCHW, stored contiguously rather than as a strided view
        batched = np.ascontiguousarray(normalized.transpose(2, 0, 1)[np.newaxis])

        # Report the offsets that were really applied (left/top), not the
        # fractional half-padding, so boxes map back without sub-pixel drift.
        return batched, ratio, (float(left), float(top))

    def _postprocess(
        self,
        outputs: list,
        original_shape: Tuple[int, int],
        ratio: float,
        pad: Tuple[float, float],
    ) -> List[Detection]:
        """
        Convert raw model outputs into thresholded, NMS-filtered detections.

        Args:
            outputs: Raw session outputs; only the first one is used
            original_shape: (height, width) of the original frame
            ratio: Resize factor returned by _preprocess
            pad: (left, top) padding returned by _preprocess

        Returns:
            List of Detection objects with boxes in original-frame pixels.
        """
        orig_h, orig_w = original_shape
        pad_w, pad_h = pad

        # Handle different output formats
        output = np.asarray(outputs[0])
        if output.ndim == 3:
            output = output[0]
        if output.ndim != 2 or output.size == 0:
            return []

        # Transpose if needed (num_classes+4 x num_boxes -> num_boxes x num_classes+4)
        if output.shape[0] < output.shape[1]:
            output = output.T

        num_cols = output.shape[1]
        if num_cols < 5:
            return []

        # Parse based on format. The layout is guessed from the column count:
        # exactly 85 columns is treated as 4 box + objectness + 80 classes.
        # NOTE(review): a model without objectness but with 81 classes would be
        # misread here; confirm against the models actually deployed.
        row_index = np.arange(output.shape[0])
        if num_cols == 85:  # YOLOv5 format with obj_conf
            class_scores = output[:, 5:]
            class_ids = class_scores.argmax(axis=1)
            scores = output[:, 4] * class_scores[row_index, class_ids]
        else:  # YOLOv8/v9 format without obj_conf
            class_scores = output[:, 4:]
            class_ids = class_scores.argmax(axis=1)
            scores = class_scores[row_index, class_ids]

        # Threshold everything at once; `>=` also discards NaN scores.
        selected = scores >= self.conf_threshold
        if not selected.any():
            return []

        scores = scores[selected].astype(np.float64)
        class_ids = class_ids[selected]
        cx, cy, w, h = output[selected, :4].astype(np.float64).T

        # Convert to xyxy and scale back to original-frame coordinates
        half_w = w / 2
        half_h = h / 2
        boxes = np.stack(
            [
                (cx - half_w - pad_w) / ratio,
                (cy - half_h - pad_h) / ratio,
                (cx + half_w - pad_w) / ratio,
                (cy + half_h - pad_h) / ratio,
            ],
            axis=1,
        )

        # Clip to the frame
        boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, orig_w)
        boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, orig_h)

        # Apply NMS, then build Detection objects only for the survivors
        keep = self.nms(boxes, scores, self.nms_threshold)

        detections = []
        for idx in keep:
            class_id = int(class_ids[idx])
            x1, y1, x2, y2 = (float(v) for v in boxes[idx])
            detections.append(
                Detection(
                    class_id=class_id,
                    class_name=self.class_names.get(class_id, str(class_id)),
                    confidence=float(scores[idx]),
                    bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2),
                )
            )

        return detections

    def release(self) -> None:
        """Release ONNX session."""
        # Dropping the reference is all that is done here; detect() treats a
        # None session as "not initialized".
        self.session = None
        logger.info("ONNX detector released")
|