295 lines
7.9 KiB
Python
295 lines
7.9 KiB
Python
#!/usr/bin/env python3
"""
Convert PyTorch YOLO model to ONNX format.

This script converts YOLOv9/v8/v5 models to ONNX format for CPU inference.

Usage:
    python scripts/convert_to_onnx.py --input models/yolov9t.pt --output models/yolov9t.onnx
"""
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def convert_to_onnx(
    input_path: str,
    output_path: str,
    imgsz: int = 640,
    simplify: bool = True,
    opset: int = 12,
    dynamic: bool = False,
    half: bool = False,
) -> bool:
    """
    Convert PyTorch YOLO model to ONNX.

    The export is delegated to ``ultralytics.YOLO.export``. If Ultralytics
    writes the file somewhere other than ``output_path``, it is moved there
    afterwards.

    Args:
        input_path: Path to .pt model
        output_path: Output ONNX path
        imgsz: Input image size
        simplify: Simplify ONNX graph
        opset: ONNX opset version
        dynamic: Enable dynamic input shapes
        half: Export as FP16 (requires GPU)

    Returns:
        True if the ONNX file exists at ``output_path`` afterwards. False on
        any failure; errors are printed, not raised.
    """
    # Only a failure of this import means "ultralytics is missing". Import
    # errors raised later by the exporter (e.g. a missing ``onnx`` package)
    # are reported with their real message by the generic handler below.
    try:
        from ultralytics import YOLO
    except ImportError:
        print("Error: ultralytics package not found")
        print("Install with: pip install ultralytics")
        return False

    try:
        print(f"Loading model: {input_path}")
        model = YOLO(input_path)

        print("Exporting to ONNX...")
        print(f" Image size: {imgsz}")
        print(f" Simplify: {simplify}")
        print(f" Opset: {opset}")
        print(f" Dynamic: {dynamic}")
        print(f" Half: {half}")

        export_path = Path(
            model.export(
                format='onnx',
                imgsz=imgsz,
                simplify=simplify,
                opset=opset,
                dynamic=dynamic,
                half=half,
            )
        )
        output_path = Path(output_path)

        # Compare resolved paths so that a relative and an absolute spelling
        # of the same file are not treated as different destinations.
        if export_path.resolve() != output_path.resolve():
            import shutil

            output_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(export_path), str(output_path))

        if not output_path.exists():
            print("Export failed: output file not created")
            return False

        size_mb = output_path.stat().st_size / 1024 / 1024
        print("\nExport successful!")
        print(f" Output: {output_path}")
        print(f" Size: {size_mb:.1f} MB")
        return True

    except Exception as e:
        print(f"Error: {e}")
        import traceback

        traceback.print_exc()
        return False
|
|
|
|
|
|
def verify_onnx(model_path: str) -> bool:
    """
    Verify ONNX model is valid.

    Loads the model with ``onnx``, runs ``onnx.checker.check_model``, and
    prints the name and shape of every graph input and output. Symbolic
    (dynamic) dimensions are printed by name rather than as ``0``.

    Args:
        model_path: Path to the .onnx file.

    Returns:
        True if the model passed the checker, or if ``onnx`` is not installed
        (verification is optional). False if loading or checking failed.
    """
    try:
        import onnx
    except ImportError:
        print("Note: Install 'onnx' package to verify model: pip install onnx")
        return True  # Don't fail if onnx not installed

    def _dims(value_info):
        # Each dimension holds a concrete size (dim_value), a symbolic name
        # such as a dynamic batch axis (dim_param), or neither (unknown).
        dims = []
        for d in value_info.type.tensor_type.shape.dim:
            if d.HasField("dim_value"):
                dims.append(d.dim_value)
            else:
                dims.append(d.dim_param or "?")
        return dims

    try:
        print("\nVerifying ONNX model...")
        model = onnx.load(model_path)
        onnx.checker.check_model(model)

        for label, values in ((" Inputs:", model.graph.input), (" Outputs:", model.graph.output)):
            print(label)
            for value_info in values:
                print(f" {value_info.name}: {_dims(value_info)}")

        print(" Model is valid!")
        return True

    except Exception as e:
        print(f" Verification failed: {e}")
        return False
|
|
|
|
|
|
def test_inference(model_path: str, imgsz: int = 640, runs: int = 10) -> bool:
    """
    Test ONNX model inference.

    Runs the model with ONNX Runtime on CPU using a random input. It does one
    warmup pass, then ``runs`` timed passes, and prints the average latency,
    FPS and first output shape.

    Args:
        model_path: Path to the .onnx file.
        imgsz: Spatial size used for any height/width axis the model leaves
            dynamic.
        runs: Number of timed inference passes (at least 1 is performed).

    Returns:
        True on success, or if ``onnxruntime``/``numpy`` is not installed
        (the test is optional). False if the session or inference fails.
    """
    try:
        import onnxruntime as ort
        import numpy as np
    except ImportError:
        print("Note: Install 'onnxruntime' to test inference: pip install onnxruntime")
        return True

    import time

    try:
        print("\nTesting inference...")

        sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        model_input = sess.get_inputs()[0]

        # ORT reports static axes as ints and dynamic axes as a name string or
        # None, so treat only positive ints as fixed sizes. Anything else
        # falls back to batch 1, 3 channels and imgsz x imgsz.
        # NOTE(review): assumes a single 4-D NCHW image input — confirm for
        # non-YOLO models.
        defaults = (1, 3, imgsz, imgsz)
        shape = tuple(
            dim if isinstance(dim, int) and dim > 0 else default
            for dim, default in zip(model_input.shape, defaults)
        )

        # FP16 exports expect float16 input; feeding float32 makes ORT reject
        # the request.
        dtype = {
            'tensor(float16)': np.float16,
            'tensor(double)': np.float64,
        }.get(model_input.type, np.float32)

        feed = {model_input.name: np.random.randn(*shape).astype(dtype)}

        # Warmup: keep one-off initialisation cost out of the timings.
        sess.run(None, feed)

        times = []
        outputs = None
        for _ in range(max(1, runs)):
            start = time.perf_counter()
            outputs = sess.run(None, feed)
            times.append(time.perf_counter() - start)

        avg_time = sum(times) / len(times) * 1000
        fps = 1000 / avg_time if avg_time > 0 else float('inf')

        print(f" Inference time: {avg_time:.1f} ms")
        print(f" FPS: {fps:.1f}")
        print(f" Output shape: {outputs[0].shape}")

        return True

    except Exception as e:
        print(f" Inference test failed: {e}")
        return False


# Despite its name this is a helper, not a pytest test; keep pytest from
# collecting it (and demanding a `model_path` fixture) if it is imported
# into a test module.
test_inference.__test__ = False
|
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Parses arguments and converts the model with ``convert_to_onnx``. Unless
    skipped, it then validates the result with ``verify_onnx`` and benchmarks
    it with ``test_inference``.

    Exits with status 1 in any of these cases:
      - the input file is missing;
      - the conversion fails;
      - a requested verification or inference step reports failure.
    """
    parser = argparse.ArgumentParser(
        description="Convert YOLO model to ONNX format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Basic conversion
python scripts/convert_to_onnx.py --input models/yolov9t.pt

# With custom output path
python scripts/convert_to_onnx.py --input models/yolov9t.pt --output models/custom.onnx

# With custom image size
python scripts/convert_to_onnx.py --input models/yolov9t.pt --imgsz 416

# Enable dynamic batch size
python scripts/convert_to_onnx.py --input models/yolov9t.pt --dynamic
"""
    )

    parser.add_argument(
        '--input', '-i',
        type=str,
        required=True,
        help='Input PyTorch model path (.pt)'
    )
    parser.add_argument(
        '--output', '-o',
        type=str,
        default=None,
        help='Output ONNX model path (default: same name with .onnx)'
    )
    parser.add_argument(
        '--imgsz',
        type=int,
        default=640,
        help='Input image size (default: 640)'
    )
    parser.add_argument(
        '--opset',
        type=int,
        default=12,
        help='ONNX opset version (default: 12)'
    )
    parser.add_argument(
        '--no-simplify',
        action='store_true',
        help='Disable ONNX graph simplification'
    )
    parser.add_argument(
        '--dynamic',
        action='store_true',
        help='Enable dynamic input shapes'
    )
    parser.add_argument(
        '--half',
        action='store_true',
        help='Export as FP16 (requires GPU)'
    )
    parser.add_argument(
        '--skip-verify',
        action='store_true',
        help='Skip model verification'
    )
    parser.add_argument(
        '--skip-test',
        action='store_true',
        help='Skip inference test'
    )

    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        print(f"Error: Input file not found: {input_path}")
        sys.exit(1)

    # Default output path: same location and stem as the input, .onnx suffix.
    output_path = args.output or str(input_path.with_suffix('.onnx'))

    print("=" * 60)
    print("YOLO to ONNX Conversion")
    print("=" * 60)

    success = convert_to_onnx(
        input_path=str(input_path),
        output_path=output_path,
        imgsz=args.imgsz,
        simplify=not args.no_simplify,
        opset=args.opset,
        dynamic=args.dynamic,
        half=args.half,
    )
    if not success:
        sys.exit(1)

    # Both checks report failure through their return value; propagate it to
    # the exit status so scripts/CI notice a broken export.
    if not args.skip_verify and not verify_onnx(output_path):
        print("\nError: ONNX model verification failed")
        sys.exit(1)

    if not args.skip_test and not test_inference(output_path, args.imgsz):
        print("\nError: ONNX inference test failed")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("CONVERSION COMPLETE")
    print("=" * 60)
    print("\nTo use with Frigate-Mini:")
    print(f" python scripts/frigate_mini.py --model {output_path} --video input/video.mp4")
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
|