247 lines
6.9 KiB
Python
247 lines
6.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Convert ONNX/PT model to RKNN format.
|
|
|
|
This script converts YOLOv9 models to RKNN format for Rockchip NPU inference.
|
|
|
|
Requirements:
|
|
- rknn-toolkit2 (for x86 host)
|
|
- ONNX model file
|
|
|
|
Usage:
|
|
python scripts/convert_to_rknn.py --input models/yolov9t.onnx --output models/yolov9t.rknn
|
|
python scripts/convert_to_rknn.py --input models/yolov9t.pt --output models/yolov9t.rknn
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def convert_pt_to_onnx(pt_path: str, onnx_path: str, imgsz: int = 640) -> bool:
    """Convert a PyTorch (Ultralytics YOLO) checkpoint to ONNX.

    Args:
        pt_path: Path to the ``.pt`` weights file.
        onnx_path: Destination path for the exported ``.onnx`` file. Missing
            parent directories are created.
        imgsz: Square input image size passed to the exporter.

    Returns:
        True on success. False if ``ultralytics`` is not installed, or if
        loading, exporting or moving the file fails. Errors are printed
        rather than raised.
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        print("Error: ultralytics not installed")
        print("Install with: pip install ultralytics")
        return False

    try:
        import shutil

        print(f"Loading PyTorch model: {pt_path}")
        model = YOLO(pt_path)

        print(f"Exporting to ONNX: {onnx_path}")
        result = model.export(format='onnx', imgsz=imgsz, simplify=True)

        # Prefer the path the exporter reports having written; fall back to
        # the conventional "<weights>.onnx" next to the input if it returns
        # nothing.
        exported = Path(result) if result else Path(pt_path).with_suffix('.onnx')
        if not exported.exists():
            raise FileNotFoundError(f"Exported ONNX file not found: {exported}")

        # Compare resolved paths so "./a.onnx" and "a.onnx" count as the same
        # file and no self-move is attempted.
        target = Path(onnx_path)
        if exported.resolve() != target.resolve():
            target.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(exported), str(target))

        print("ONNX export complete")
        return True

    except Exception as e:
        import traceback

        print(f"Error converting to ONNX: {e}")
        traceback.print_exc()
        return False
|
|
|
|
|
|
def convert_onnx_to_rknn(
    onnx_path: str,
    rknn_path: str,
    target_platform: str = 'rk3588',
    input_size: tuple = (640, 640),
    quantize: bool = True,
    dataset_path: str = None,
) -> bool:
    """Convert an ONNX model to RKNN format with rknn-toolkit2.

    Args:
        onnx_path: Path to the source ``.onnx`` model.
        rknn_path: Destination path for the ``.rknn`` file.
        target_platform: Rockchip platform name forwarded to ``RKNN.config``.
        input_size: Only reported in the log. It is not forwarded to the
            toolkit, so the ONNX graph's own input shape applies.
        quantize: Whether to build with INT8 quantization.
        dataset_path: Optional calibration dataset file, passed to
            ``RKNN.build`` when quantizing.

    Returns:
        True if the model was built and exported. False if rknn-toolkit2 is
        missing, if any toolkit step returns non-zero, or if an exception is
        raised (its traceback is printed).
    """
    # Guard only the import itself, so that an ImportError raised deeper in
    # the toolkit (e.g. during build) is not misreported as "not installed".
    try:
        from rknn.api import RKNN
    except ImportError:
        print("Error: rknn-toolkit2 not installed")
        print("Install with: pip install rknn-toolkit2")
        print("Note: rknn-toolkit2 only works on x86 Linux")
        return False

    print("Converting ONNX to RKNN")
    print(f" Input: {onnx_path}")
    print(f" Output: {rknn_path}")
    print(f" Platform: {target_platform}")
    print(f" Input size: {input_size}")
    print(f" Quantize: {quantize}")

    rknn = None
    try:
        rknn = RKNN(verbose=True)

        print("\n[1/5] Configuring RKNN...")
        # mean 0 / std 255 per channel: per the toolkit's (x - mean) / std
        # normalization, this presumably maps 0-255 pixels to [0, 1].
        rknn.config(
            mean_values=[[0, 0, 0]],
            std_values=[[255, 255, 255]],
            target_platform=target_platform,
            quantized_algorithm='normal',
            quantized_method='channel',
        )

        print("\n[2/5] Loading ONNX model...")
        ret = rknn.load_onnx(model=onnx_path)
        if ret != 0:
            print(f"Failed to load ONNX model: {ret}")
            return False

        print("\n[3/5] Building RKNN model...")
        if quantize and dataset_path:
            # Calibrate quantization ranges on the supplied dataset.
            ret = rknn.build(do_quantization=True, dataset=dataset_path)
        else:
            if quantize:
                # NOTE(review): without calibration data the toolkit may
                # reject the build or produce poor accuracy — verify for
                # the installed rknn-toolkit2 version.
                print("Warning: quantizing without a calibration dataset "
                      "(--dataset); results may be inaccurate")
            ret = rknn.build(do_quantization=quantize)
        if ret != 0:
            print(f"Failed to build RKNN model: {ret}")
            return False

        print(f"\n[4/5] Exporting RKNN model to {rknn_path}...")
        ret = rknn.export_rknn(rknn_path)
        if ret != 0:
            print(f"Failed to export RKNN model: {ret}")
            return False

        print("\n[5/5] Conversion complete!")
        return True

    except Exception as e:
        import traceback

        print(f"Error converting to RKNN: {e}")
        traceback.print_exc()
        return False

    finally:
        # Release toolkit resources on every path, including early failures.
        # Cleanup is best-effort: a release error must not mask the result.
        if rknn is not None:
            try:
                rknn.release()
            except Exception as e:
                print(f"Warning: failed to release RKNN resources: {e}")
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the RKNN conversion CLI."""
    platforms = ['rk3588', 'rk3568', 'rk3566', 'rk3562', 'rv1106', 'rv1103']
    parser = argparse.ArgumentParser(
        description="Convert YOLO model to RKNN format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert ONNX to RKNN
  python scripts/convert_to_rknn.py \\
    --input models/yolov9t.onnx \\
    --output models/yolov9t.rknn \\
    --platform rk3588

  # Convert PT to RKNN (via ONNX)
  python scripts/convert_to_rknn.py \\
    --input models/yolov9t.pt \\
    --output models/yolov9t.rknn \\
    --platform rk3588

  # With quantization dataset
  python scripts/convert_to_rknn.py \\
    --input models/yolov9t.onnx \\
    --output models/yolov9t.rknn \\
    --platform rk3588 \\
    --dataset calibration_images.txt

Supported platforms:
  - rk3588 (RK3588/RK3588S)
  - rk3568 (RK3568)
  - rk3566 (RK3566)
  - rk3562 (RK3562)
  - rv1106 (RV1106)
  - rv1103 (RV1103)
"""
    )

    parser.add_argument(
        '--input', '-i',
        type=str,
        required=True,
        help='Input model path (.pt or .onnx)'
    )
    parser.add_argument(
        '--output', '-o',
        type=str,
        required=True,
        help='Output RKNN model path (.rknn)'
    )
    parser.add_argument(
        '--platform', '-p',
        type=str,
        default='rk3588',
        choices=platforms,
        help='Target Rockchip platform'
    )
    parser.add_argument(
        '--input-size',
        type=int,
        default=640,
        help='Model input size (assumes square)'
    )
    parser.add_argument(
        '--no-quantize',
        action='store_true',
        help='Disable INT8 quantization (use FP16)'
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default=None,
        help='Calibration dataset file (text file with image paths)'
    )
    return parser


def main():
    """Command-line entry point: validate arguments, then convert PT/ONNX to RKNN.

    A ``.pt`` input is first exported to ``<input>.onnx`` next to the input.
    All inputs are validated before anything is written to disk.

    Exits with status 1 if the input is missing or has an unsupported
    extension, if the calibration dataset is missing while quantizing, or if
    either conversion step fails. argparse exits with status 2 on invalid
    arguments.
    """
    parser = _build_parser()
    args = parser.parse_args()

    input_path = Path(args.input)
    output_path = Path(args.output)

    if not input_path.exists():
        print(f"Error: Input file not found: {input_path}")
        sys.exit(1)

    # Match extensions case-insensitively and reject anything we cannot
    # convert, rather than handing e.g. a .tflite file to the ONNX loader.
    suffix = input_path.suffix.lower()
    if suffix not in ('.pt', '.onnx'):
        print(f"Error: Unsupported input format '{input_path.suffix}' "
              f"(expected .pt or .onnx)")
        sys.exit(1)

    if args.input_size <= 0:
        parser.error("--input-size must be a positive integer")

    # The dataset is only consumed when quantizing; fail fast if it is missing.
    if args.dataset and not args.no_quantize and not Path(args.dataset).exists():
        print(f"Error: Calibration dataset not found: {args.dataset}")
        sys.exit(1)

    # Ensure the output directory exists (only after validation succeeded).
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("RKNN Model Conversion")
    print("=" * 60)

    if suffix == '.pt':
        onnx_path = input_path.with_suffix('.onnx')
        print("\nStep 1: Converting PT to ONNX")
        if not convert_pt_to_onnx(str(input_path), str(onnx_path), args.input_size):
            sys.exit(1)
        input_path = onnx_path
        print("\nStep 2: Converting ONNX to RKNN")
    else:
        print("\nConverting ONNX to RKNN")

    success = convert_onnx_to_rknn(
        onnx_path=str(input_path),
        rknn_path=str(output_path),
        target_platform=args.platform,
        input_size=(args.input_size, args.input_size),
        quantize=not args.no_quantize,
        dataset_path=args.dataset,
    )

    if not success:
        print("\nConversion failed!")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("CONVERSION COMPLETE")
    print("=" * 60)
    print(f"Output: {output_path}")
    print(f"Size: {output_path.stat().st_size / 1024 / 1024:.1f} MB")
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
|