add ONNX sanitize required for Ultralytics

This commit is contained in:
2026-06-09 12:18:41 -04:00
parent c2d3f44498
commit 9dc6f478bd
3 changed files with 59 additions and 2 deletions

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Any
import yaml
from sanitize_onnx import sanitize_onnx
from ultralytics import YOLO # type: ignore[reportMissingImports]
@@ -103,7 +104,8 @@ def main() -> None:
copy_if_exists(trained_weights, model_dir / "best.pt")
trained_model = YOLO(str(trained_weights))
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
copy_if_exists(onnx_path, model_dir / "model.onnx")
saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
print(f"Saved {saved_onnx_path}")
metrics = {
"model": args.model,
@@ -114,7 +116,7 @@ def main() -> None:
"patience": args.patience,
"data_yaml": str(data_yaml),
"weights": str(trained_weights),
"onnx": str(onnx_path),
"onnx": str(saved_onnx_path),
}
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
print(f"Saved model artifacts to {model_dir}")