add ONNX sanitize required for Ultralytics

This commit is contained in:
2026-06-09 12:18:41 -04:00
parent c2d3f44498
commit 9dc6f478bd
3 changed files with 59 additions and 2 deletions

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""
from __future__ import annotations
import argparse
from pathlib import Path
import onnx # type: ignore[reportMissingImports]
def sanitize_onnx(path: Path, output_path: Path | None = None) -> Path:
model = onnx.load(path)
io_names = {value.name for value in (*model.graph.input, *model.graph.output)}
retained_value_info = [value for value in model.graph.value_info if value.name not in io_names]
destination = output_path or path
if len(retained_value_info) != len(model.graph.value_info):
del model.graph.value_info[:]
model.graph.value_info.extend(retained_value_info)
destination.parent.mkdir(parents=True, exist_ok=True)
onnx.save(model, destination)
return destination
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("onnx_path", type=Path)
parser.add_argument("--output", type=Path)
args = parser.parse_args()
written = sanitize_onnx(args.onnx_path, args.output)
print(f"Saved sanitized ONNX model to {written}")
if __name__ == "__main__":
main()

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Any
import yaml
from sanitize_onnx import sanitize_onnx
from ultralytics import YOLO # type: ignore[reportMissingImports]
@@ -103,7 +104,8 @@ def main() -> None:
copy_if_exists(trained_weights, model_dir / "best.pt")
trained_model = YOLO(str(trained_weights))
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
copy_if_exists(onnx_path, model_dir / "model.onnx")
saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
print(f"Saved {saved_onnx_path}")
metrics = {
"model": args.model,
@@ -114,7 +116,7 @@ def main() -> None:
"patience": args.patience,
"data_yaml": str(data_yaml),
"weights": str(trained_weights),
"onnx": str(onnx_path),
"onnx": str(saved_onnx_path),
}
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
print(f"Saved model artifacts to {model_dir}")