diff --git a/convert_to_rknn.py b/convert_to_rknn.py new file mode 100644 index 0000000..04fb89d --- /dev/null +++ b/convert_to_rknn.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +""" + python3 convert_to_rknn.py \ + --input krg-tuang-atas-yolov9t-best.pt \ + --target-platform rk3588 \ + --no-simplify \ + --mean-values 0,0,0 \ + --std-values 255,255,255 \ + --dynamic-shapes "1,3,320,320" +""" + +from __future__ import annotations + +import argparse +import shutil +import sys +from pathlib import Path + + +def _parse_list(value: str, name: str) -> list[float]: + if value is None: + return [] + try: + return [float(v.strip()) for v in value.split(",") if v.strip() != ""] + except ValueError as exc: + raise argparse.ArgumentTypeError( + f"Invalid {name} list: {value!r}. Use comma-separated numbers." + ) from exc + + +def _add_bool_arg(parser: argparse.ArgumentParser, name: str, default: bool) -> None: + group = parser.add_mutually_exclusive_group() + group.add_argument(f"--{name}", dest=name, action="store_true") + group.add_argument(f"--no-{name}", dest=name, action="store_false") + parser.set_defaults(**{name: default}) + + +def export_pt_to_onnx( + pt_path: Path, + onnx_output: Path | None, + imgsz: int, + opset: int, + simplify: bool, + dynamic: bool, + half: bool, + verbose: bool, +) -> Path: + try: + from ultralytics import YOLO # type: ignore + except Exception as exc: + raise RuntimeError( + "Ultralytics is required to export .pt to ONNX. " + "Install with: pip install ultralytics" + ) from exc + + model = YOLO(str(pt_path)) + export_result = model.export( + format="onnx", + imgsz=imgsz, + opset=opset, + simplify=simplify, + dynamic=dynamic, + half=half, + verbose=verbose, + ) + + exported_path = None + if isinstance(export_result, (str, Path)): + exported_path = Path(export_result) + + if exported_path is None: + candidate = pt_path.with_suffix(".onnx") + if candidate.exists(): + exported_path = candidate + else: + onnx_files = sorted( + pt_path.parent.glob("*.onnx"), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + if onnx_files: + exported_path = onnx_files[0] + + if exported_path is None or not exported_path.exists(): + raise RuntimeError("ONNX export did not produce a file.") + + if onnx_output is not None: + onnx_output = onnx_output.resolve() + if exported_path.resolve() != onnx_output: + onnx_output.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(exported_path, onnx_output) + return onnx_output + + return exported_path + + +def convert_onnx_to_rknn( + onnx_path: Path, + rknn_output: Path, + dataset: Path | None, + quantize: bool, + target_platform: str | None, + mean_values: list[float], + std_values: list[float], + quantized_dtype: str | None, + verbose: bool, + dynamic_shapes: list[list[int]] | None, +) -> None: + try: + import onnx # type: ignore + except ImportError as exc: + raise RuntimeError( + "onnx is required by rknn-toolkit2. Install with: pip install 'onnx>=1.16.0'." + ) from exc + + if not hasattr(onnx, "mapping") and hasattr(onnx, "_mapping"): + onnx.mapping = onnx._mapping # type: ignore[attr-defined] + + if not hasattr(onnx, "mapping"): + raise RuntimeError( + "Incompatible onnx version detected. rknn-toolkit2 expects onnx.mapping. " + "Install a compatible version (e.g. pip install 'onnx>=1.16.0')." + ) + + # rknn-toolkit2 expects onnx.mapping.TENSOR_TYPE_TO_NP_TYPE and NP_TYPE_TO_TENSOR_TYPE + if not hasattr(onnx.mapping, "TENSOR_TYPE_TO_NP_TYPE"): + try: + import numpy as np # type: ignore + mapping = {} + for value in onnx.TensorProto.DataType.values(): + try: + np_type = onnx.helper.tensor_dtype_to_np_dtype(value) + if isinstance(np_type, type) or isinstance(np_type, np.dtype): + mapping[value] = np_type + except Exception: + continue + onnx.mapping.TENSOR_TYPE_TO_NP_TYPE = mapping # type: ignore[attr-defined] + if hasattr(onnx, "_mapping"): + onnx._mapping.TENSOR_TYPE_TO_NP_TYPE = mapping # type: ignore[attr-defined] + except Exception as exc: + raise RuntimeError( + "onnx mapping is missing TENSOR_TYPE_TO_NP_TYPE and could not be built. " + "Try onnx==1.17.* or install numpy." + ) from exc + + if not hasattr(onnx.mapping, "NP_TYPE_TO_TENSOR_TYPE"): + try: + import numpy as np # type: ignore + inverse = {} + for tensor_type, np_type in onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.items(): + try: + dtype = np.dtype(np_type) + inverse[dtype] = tensor_type + inverse[dtype.type] = tensor_type + except Exception: + continue + onnx.mapping.NP_TYPE_TO_TENSOR_TYPE = inverse # type: ignore[attr-defined] + if hasattr(onnx, "_mapping"): + onnx._mapping.NP_TYPE_TO_TENSOR_TYPE = inverse # type: ignore[attr-defined] + except Exception as exc: + raise RuntimeError( + "onnx mapping is missing NP_TYPE_TO_TENSOR_TYPE and could not be built. " + "Try onnx==1.17.* or install numpy." + ) from exc + + try: + from rknn.api import RKNN # type: ignore + except Exception as exc: + raise RuntimeError( + "RKNN Toolkit is required. Install rknn-toolkit or rknn-toolkit2." + ) from exc + + rknn = RKNN(verbose=verbose) + + config_kwargs = {} + if target_platform: + config_kwargs["target_platform"] = target_platform + if mean_values: + config_kwargs["mean_values"] = [mean_values] + if std_values: + config_kwargs["std_values"] = [std_values] + if quantized_dtype: + config_kwargs["quantized_dtype"] = quantized_dtype + if dynamic_shapes: + # rknn-toolkit2 expects: [[[shape]]] where shape is the full shape list + # Format: list(inputs) -> list(list(one_shape)) -> shape is list of dims + # For single input with shape [1, 3, 320, 320]: [[[1, 3, 320, 320]]] + first_shape = dynamic_shapes[0] + config_kwargs["dynamic_input"] = [[first_shape]] + if config_kwargs: + rknn.config(**config_kwargs) + + ret = rknn.load_onnx(model=str(onnx_path)) + if ret != 0: + raise RuntimeError("Failed to load ONNX model.") + + if quantize and dataset is None: + raise RuntimeError("Quantization enabled but no dataset was provided.") + + ret = rknn.build( + do_quantization=quantize, + dataset=str(dataset) if dataset else None, + ) + if ret != 0: + raise RuntimeError("Failed to build RKNN model.") + + rknn_output.parent.mkdir(parents=True, exist_ok=True) + ret = rknn.export_rknn(str(rknn_output)) + if ret != 0: + raise RuntimeError("Failed to export RKNN model.") + + rknn.release() + + +def main() -> int: + parser = argparse.ArgumentParser(description="Convert .pt to RKNN.") + parser.add_argument("--input", required=True, help="Path to .pt file.") + parser.add_argument( + "--output", + help="Path to output .rknn file. Defaults next to input.", + ) + parser.add_argument( + "--keep-onnx", + action="store_true", + help="Keep the intermediate ONNX file.", + ) + parser.add_argument("--imgsz", type=int, default=320, help="Export image size.") + parser.add_argument("--opset", type=int, default=11, help="ONNX opset version.") + _add_bool_arg(parser, "simplify", True) + _add_bool_arg(parser, "dynamic", False) + _add_bool_arg(parser, "half", False) + parser.add_argument("--dataset", help="Calibration dataset for quantization.") + parser.add_argument( + "--quantize", + action="store_true", + help="Enable INT8 quantization (requires --dataset).", + ) + parser.add_argument( + "--no-quantize", + dest="quantize", + action="store_false", + help="Disable quantization even if dataset is provided.", + ) + parser.set_defaults(quantize=None) + parser.add_argument( + "--target-platform", + help="Target platform, e.g. rk3566, rk3588, rv1109, rv1126.", + ) + parser.add_argument( + "--mean-values", + help="Comma-separated mean values, e.g. 0,0,0.", + ) + parser.add_argument( + "--std-values", + help="Comma-separated std values, e.g. 255,255,255.", + ) + parser.add_argument( + "--quantized-dtype", + help="Quantized dtype, e.g. asymmetric_quantized-8 or symmetric_quantized-8.", + ) + parser.add_argument("--verbose", action="store_true", help="Verbose logs.") + parser.add_argument( + "--dynamic-shapes", + help=( + "Enable RKNN dynamic input shapes. Format: " + "'1,3,320,320' (single shape per input). " + "If multiple shapes provided (semicolon-separated), only the first is used." + ), + ) + + args = parser.parse_args() + + input_path = Path(args.input).expanduser().resolve() + if not input_path.exists(): + print(f"Input file not found: {input_path}", file=sys.stderr) + return 1 + + rknn_output = ( + Path(args.output).expanduser().resolve() + if args.output + else input_path.with_suffix(".rknn") + ) + + if input_path.suffix.lower() != ".pt": + print("Input must be a .pt file.", file=sys.stderr) + return 1 + + onnx_path = input_path.with_suffix(".onnx") + if args.dynamic_shapes and not args.dynamic: + print("⚠️ --dynamic-shapes set; enabling dynamic ONNX export.") + args.dynamic = True + + onnx_path = export_pt_to_onnx( + pt_path=input_path, + onnx_output=onnx_path, + imgsz=args.imgsz, + opset=args.opset, + simplify=args.simplify, + dynamic=args.dynamic, + half=args.half, + verbose=args.verbose, + ) + + mean_values = _parse_list(args.mean_values, "mean-values") if args.mean_values else [] + std_values = _parse_list(args.std_values, "std-values") if args.std_values else [] + if mean_values and std_values and len(mean_values) != len(std_values): + print("mean-values and std-values must have the same length.", file=sys.stderr) + return 1 + + dataset = Path(args.dataset).expanduser().resolve() if args.dataset else None + if dataset and not dataset.exists(): + print(f"Dataset file not found: {dataset}", file=sys.stderr) + return 1 + + quantize = args.quantize + if quantize is None: + quantize = dataset is not None + if quantize and dataset is None: + print("⚠️ --quantize set but no dataset provided; continuing without quantization.") + quantize = False + + dynamic_shapes = None + if args.dynamic_shapes: + try: + parsed_shapes = [ + [int(v) for v in shape.split(",") if v.strip()] + for shape in args.dynamic_shapes.split(";") + if shape.strip() + ] + if not parsed_shapes: + print("Invalid --dynamic-shapes format.", file=sys.stderr) + return 1 + if len(parsed_shapes) > 1: + print( + f"⚠️ Multiple shapes provided, but RKNN only supports one shape per input. " + f"Using first shape: {parsed_shapes[0]}", + file=sys.stderr, + ) + dynamic_shapes = parsed_shapes + except ValueError: + print("Invalid --dynamic-shapes format.", file=sys.stderr) + return 1 + + try: + convert_onnx_to_rknn( + onnx_path=onnx_path, + rknn_output=rknn_output, + dataset=dataset, + quantize=quantize, + target_platform=args.target_platform, + mean_values=mean_values, + std_values=std_values, + quantized_dtype=args.quantized_dtype, + verbose=args.verbose, + dynamic_shapes=dynamic_shapes, + ) + except RuntimeError as exc: + print(str(exc), file=sys.stderr) + return 1 +# finally: +# if onnx_path.exists() and not args.keep_onnx: +# try: +# onnx_path.unlink() +# except OSError: +# pass + + print(f"RKNN saved to: {rknn_output}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..86159ef --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +ultralytics +rknn-toolkit2 +onnx>=1.16.0 \ No newline at end of file