add ONNX sanitize required for Ultralytics
This commit is contained in:
@@ -227,6 +227,23 @@ The command downloads the job's `model.tar.gz`, finds `model.onnx`, uploads it t
|
||||
compilation, validation, and profiling. The uploaded source model uses the configured
|
||||
`aihub.model_name`.
|
||||
|
||||
The training example sanitizes the Ultralytics ONNX export before saving `model.onnx`. This removes graph input or
|
||||
output names, such as `output0`, that are duplicated in the ONNX `value_info` metadata and rejected by AI Hub.
|
||||
|
||||
For a model already downloaded by a failed upload attempt, sanitize the extracted ONNX file and retry using the local
|
||||
model. Replace the job name in both paths:
|
||||
|
||||
```bash
|
||||
uv run --with onnx python examples/meter-detection/source/sanitize_onnx.py \
|
||||
build/qai-hub/meter-detection/qc-cli-YYYYMMDD-HHMMSS/source/extracted/model.onnx \
|
||||
--output build/qai-hub/meter-detection/model.aihub.onnx
|
||||
|
||||
qc-cli ai-hub upload \
|
||||
examples/meter-detection/data/aihub_calibration \
|
||||
examples/meter-detection/data/inputs.npz \
|
||||
--onnx-path build/qai-hub/meter-detection/model.aihub.onnx
|
||||
```
|
||||
|
||||
If the meter-detection job is still the last training job in `.qc-cli.json`, `--from-job` can be omitted. Keeping it
|
||||
explicit prevents accidentally uploading an artifact from a different training run.
|
||||
|
||||
|
||||
38
examples/meter-detection/source/sanitize_onnx.py
Normal file
38
examples/meter-detection/source/sanitize_onnx.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import onnx # type: ignore[reportMissingImports]
|
||||
|
||||
|
||||
def sanitize_onnx(path: Path, output_path: Path | None = None) -> Path:
|
||||
model = onnx.load(path)
|
||||
io_names = {value.name for value in (*model.graph.input, *model.graph.output)}
|
||||
retained_value_info = [value for value in model.graph.value_info if value.name not in io_names]
|
||||
|
||||
destination = output_path or path
|
||||
if len(retained_value_info) != len(model.graph.value_info):
|
||||
del model.graph.value_info[:]
|
||||
model.graph.value_info.extend(retained_value_info)
|
||||
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
onnx.save(model, destination)
|
||||
return destination
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("onnx_path", type=Path)
|
||||
parser.add_argument("--output", type=Path)
|
||||
args = parser.parse_args()
|
||||
|
||||
written = sanitize_onnx(args.onnx_path, args.output)
|
||||
print(f"Saved sanitized ONNX model to {written}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -11,6 +11,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from sanitize_onnx import sanitize_onnx
|
||||
from ultralytics import YOLO # type: ignore[reportMissingImports]
|
||||
|
||||
|
||||
@@ -103,7 +104,8 @@ def main() -> None:
|
||||
copy_if_exists(trained_weights, model_dir / "best.pt")
|
||||
trained_model = YOLO(str(trained_weights))
|
||||
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
|
||||
copy_if_exists(onnx_path, model_dir / "model.onnx")
|
||||
saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
|
||||
print(f"Saved {saved_onnx_path}")
|
||||
|
||||
metrics = {
|
||||
"model": args.model,
|
||||
@@ -114,7 +116,7 @@ def main() -> None:
|
||||
"patience": args.patience,
|
||||
"data_yaml": str(data_yaml),
|
||||
"weights": str(trained_weights),
|
||||
"onnx": str(onnx_path),
|
||||
"onnx": str(saved_onnx_path),
|
||||
}
|
||||
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
||||
print(f"Saved model artifacts to {model_dir}")
|
||||
|
||||
Reference in New Issue
Block a user