add ONNX sanitize required for Ultralytics
This commit is contained in:
@@ -11,6 +11,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from sanitize_onnx import sanitize_onnx
|
||||
from ultralytics import YOLO # type: ignore[reportMissingImports]
|
||||
|
||||
|
||||
@@ -103,7 +104,8 @@ def main() -> None:
|
||||
copy_if_exists(trained_weights, model_dir / "best.pt")
|
||||
trained_model = YOLO(str(trained_weights))
|
||||
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
|
||||
copy_if_exists(onnx_path, model_dir / "model.onnx")
|
||||
saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
|
||||
print(f"Saved {saved_onnx_path}")
|
||||
|
||||
metrics = {
|
||||
"model": args.model,
|
||||
@@ -114,7 +116,7 @@ def main() -> None:
|
||||
"patience": args.patience,
|
||||
"data_yaml": str(data_yaml),
|
||||
"weights": str(trained_weights),
|
||||
"onnx": str(onnx_path),
|
||||
"onnx": str(saved_onnx_path),
|
||||
}
|
||||
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
||||
print(f"Saved model artifacts to {model_dir}")
|
||||
|
||||
Reference in New Issue
Block a user