Files
2026-02-04 15:29:36 +07:00

295 lines
7.9 KiB
Python

#!/usr/bin/env python3
"""
Convert PyTorch YOLO model to ONNX format.
This script converts YOLOv9/v8/v5 models to ONNX format for CPU inference.
Usage:
python scripts/convert_to_onnx.py --input models/yolov9t.pt --output models/yolov9t.onnx
"""
import argparse
import sys
from pathlib import Path
def convert_to_onnx(
    input_path: str,
    output_path: str,
    imgsz: int = 640,
    simplify: bool = True,
    opset: int = 12,
    dynamic: bool = False,
    half: bool = False,
) -> bool:
    """
    Convert PyTorch YOLO model to ONNX.

    The model is loaded with ``ultralytics.YOLO`` and exported through its
    ``export`` method. The exporter chooses where it writes the file, so the
    result is moved to ``output_path`` when it ended up somewhere else.
    Progress and errors are printed; no exception escapes this function.

    Args:
        input_path: Path to .pt model
        output_path: Output ONNX path
        imgsz: Input image size
        simplify: Simplify ONNX graph
        opset: ONNX opset version
        dynamic: Enable dynamic input shapes
        half: Export as FP16 (requires GPU)

    Returns:
        True when a file exists at ``output_path`` after the export; False
        when ultralytics is not installed, the export raises, or no output
        file is present afterwards.
    """
    try:
        from ultralytics import YOLO

        print(f"Loading model: {input_path}")
        model = YOLO(input_path)

        print("Exporting to ONNX...")
        print(f" Image size: {imgsz}")
        print(f" Simplify: {simplify}")
        print(f" Opset: {opset}")
        print(f" Dynamic: {dynamic}")

        # Export
        export_path = model.export(
            format='onnx',
            imgsz=imgsz,
            simplify=simplify,
            opset=opset,
            dynamic=dynamic,
            half=half,
        )

        # Move to desired location if different. Compare fully resolved paths
        # so that a relative and an absolute spelling of the same file are not
        # mistaken for two different locations.
        export_path = Path(export_path)
        output_path = Path(output_path)
        if export_path.resolve() != output_path.resolve():
            import shutil

            output_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(export_path), str(output_path))

        # Verify output
        if not output_path.exists():
            print("Export failed: output file not created")
            return False

        size_mb = output_path.stat().st_size / 1024 / 1024
        print("\nExport successful!")
        print(f" Output: {output_path}")
        print(f" Size: {size_mb:.1f} MB")
        return True
    except ImportError:
        print("Error: ultralytics package not found")
        print("Install with: pip install ultralytics")
        return False
    except Exception as e:
        # Top-level boundary for the CLI: report the failure with a traceback
        # and signal it through the return value instead of crashing.
        print(f"Error: {e}")
        import traceback

        traceback.print_exc()
        return False
def _format_dims(value_info) -> list:
    """Return the dimensions of a graph input/output as a printable list.

    Fixed dimensions are reported as integers. Symbolic (dynamic) dimensions
    carry no ``dim_value``, so their ``dim_param`` name is reported instead,
    or "?" when the dimension is entirely unspecified.
    """
    return [
        d.dim_value if d.HasField("dim_value") else (d.dim_param or "?")
        for d in value_info.type.tensor_type.shape.dim
    ]


def verify_onnx(model_path: str) -> bool:
    """Verify ONNX model is valid.

    Loads the model with ``onnx``, runs ``onnx.checker.check_model`` on it and
    prints the name and shape of every graph input and output.

    Args:
        model_path: Path to the ONNX file to check.

    Returns:
        True when the model passes the checker, and also when the ``onnx``
        package is not installed (verification is then skipped with a note).
        False when loading or checking the model raises.
    """
    try:
        import onnx

        print("\nVerifying ONNX model...")
        model = onnx.load(model_path)
        onnx.checker.check_model(model)

        # Get model info
        print(" Inputs:")
        for inp in model.graph.input:
            print(f" {inp.name}: {_format_dims(inp)}")
        print(" Outputs:")
        for out in model.graph.output:
            print(f" {out.name}: {_format_dims(out)}")

        print(" Model is valid!")
        return True
    except ImportError:
        print("Note: Install 'onnx' package to verify model: pip install onnx")
        return True  # Don't fail if onnx not installed
    except Exception as e:
        print(f" Verification failed: {e}")
        return False
def test_inference(model_path: str, imgsz: int = 640) -> bool:
    """Test ONNX model inference.

    Opens the model with onnxruntime on the CPU execution provider, feeds it a
    random ``(batch, 3, imgsz, imgsz)`` tensor, and prints the average latency
    over 10 timed runs (after one warmup run) together with the shape of the
    first output.

    Args:
        model_path: Path to the ONNX file to run.
        imgsz: Height and width of the square dummy input.

    Returns:
        True when inference runs, and also when ``onnxruntime`` (or numpy) is
        not installed (the test is then skipped with a note). False when
        creating the session or running inference raises.
    """
    try:
        import onnxruntime as ort
        import numpy as np
        import time

        print("\nTesting inference...")

        # Create session
        sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

        # Get input info
        model_input = sess.get_inputs()[0]
        input_name = model_input.name

        # A dynamic leading dimension is reported as a symbolic name (any
        # string) or None rather than an int; use a batch of 1 in that case.
        leading_dim = model_input.shape[0]
        if isinstance(leading_dim, int) and leading_dim > 0:
            batch_size = leading_dim
        else:
            batch_size = 1

        # Match the element type the model declares so that FP16 exports are
        # not rejected for receiving float32 data.
        dtype = np.float16 if model_input.type == 'tensor(float16)' else np.float32
        dummy_input = np.random.randn(batch_size, 3, imgsz, imgsz).astype(dtype)

        # Warmup
        sess.run(None, {input_name: dummy_input})

        # Benchmark (perf_counter is monotonic and high resolution)
        times = []
        for _ in range(10):
            start = time.perf_counter()
            outputs = sess.run(None, {input_name: dummy_input})
            times.append(time.perf_counter() - start)

        avg_time = sum(times) / len(times) * 1000
        # Guard against a zero reading on very coarse clocks.
        fps = 1000 / avg_time if avg_time > 0 else float('inf')
        print(f" Inference time: {avg_time:.1f} ms")
        print(f" FPS: {fps:.1f}")
        print(f" Output shape: {outputs[0].shape}")
        return True
    except ImportError:
        print("Note: Install 'onnxruntime' to test inference: pip install onnxruntime")
        return True
    except Exception as e:
        print(f" Inference test failed: {e}")
        return False


# The ``test_`` prefix matches pytest's default collection pattern; opt out so
# that importing this helper into a test module does not make pytest try to
# run it as a test case.
test_inference.__test__ = False
def main():
    """Command-line entry point: convert a YOLO model, then check the result.

    Parses the command line, exports the model with ``convert_to_onnx`` and,
    unless skipped via ``--skip-verify`` / ``--skip-test``, runs
    ``verify_onnx`` and ``test_inference`` on the output.

    Exits with status 1 when the input file does not exist, when the
    conversion fails, or when a post-export check that was run reports
    failure.
    """
    parser = argparse.ArgumentParser(
        description="Convert YOLO model to ONNX format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Basic conversion
python scripts/convert_to_onnx.py --input models/yolov9t.pt
# With custom output path
python scripts/convert_to_onnx.py --input models/yolov9t.pt --output models/custom.onnx
# With custom image size
python scripts/convert_to_onnx.py --input models/yolov9t.pt --imgsz 416
# Enable dynamic batch size
python scripts/convert_to_onnx.py --input models/yolov9t.pt --dynamic
"""
    )
    parser.add_argument(
        '--input', '-i',
        type=str,
        required=True,
        help='Input PyTorch model path (.pt)'
    )
    parser.add_argument(
        '--output', '-o',
        type=str,
        default=None,
        help='Output ONNX model path (default: same name with .onnx)'
    )
    parser.add_argument(
        '--imgsz',
        type=int,
        default=640,
        help='Input image size (default: 640)'
    )
    parser.add_argument(
        '--opset',
        type=int,
        default=12,
        help='ONNX opset version (default: 12)'
    )
    parser.add_argument(
        '--no-simplify',
        action='store_true',
        help='Disable ONNX graph simplification'
    )
    parser.add_argument(
        '--dynamic',
        action='store_true',
        help='Enable dynamic input shapes'
    )
    parser.add_argument(
        '--half',
        action='store_true',
        help='Export as FP16 (requires GPU)'
    )
    parser.add_argument(
        '--skip-verify',
        action='store_true',
        help='Skip model verification'
    )
    parser.add_argument(
        '--skip-test',
        action='store_true',
        help='Skip inference test'
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        print(f"Error: Input file not found: {input_path}")
        sys.exit(1)

    # Default output path
    if args.output:
        output_path = args.output
    else:
        output_path = str(input_path.with_suffix('.onnx'))

    print("=" * 60)
    print("YOLO to ONNX Conversion")
    print("=" * 60)

    # Convert
    success = convert_to_onnx(
        input_path=str(input_path),
        output_path=output_path,
        imgsz=args.imgsz,
        simplify=not args.no_simplify,
        opset=args.opset,
        dynamic=args.dynamic,
        half=args.half,
    )
    if not success:
        sys.exit(1)

    # Run every requested check (so all diagnostics are printed), but remember
    # whether any of them failed instead of discarding the results.
    checks_ok = True

    # Verify
    if not args.skip_verify:
        checks_ok = verify_onnx(output_path) and checks_ok

    # Test inference
    if not args.skip_test:
        checks_ok = test_inference(output_path, args.imgsz) and checks_ok

    if not checks_ok:
        # Do not announce success for a model that failed its checks.
        print("\nError: exported model failed post-export checks (see messages above)")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("CONVERSION COMPLETE")
    print("=" * 60)
    print("\nTo use with Frigate-Mini:")
    print(f" python scripts/frigate_mini.py --model {output_path} --video input/video.mp4")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()