add ONNX sanitize required for Ultralytics

This commit is contained in:
2026-06-09 12:18:41 -04:00
parent c2d3f44498
commit 9dc6f478bd
3 changed files with 59 additions and 2 deletions

View File

@@ -227,6 +227,23 @@ The command downloads the job's `model.tar.gz`, finds `model.onnx`, uploads it t
compilation, validation, and profiling. The uploaded source model uses the configured compilation, validation, and profiling. The uploaded source model uses the configured
`aihub.model_name`. `aihub.model_name`.
The training example sanitizes the Ultralytics ONNX export before saving `model.onnx`. This removes graph input or
output names, such as `output0`, that are duplicated in the ONNX `value_info` metadata and rejected by AI Hub.
For a model already downloaded by a failed upload attempt, sanitize the extracted ONNX file and retry using the local
model. Replace the job name in both paths:
```bash
uv run --with onnx python examples/meter-detection/source/sanitize_onnx.py \
build/qai-hub/meter-detection/qc-cli-YYYYMMDD-HHMMSS/source/extracted/model.onnx \
--output build/qai-hub/meter-detection/model.aihub.onnx
qc-cli ai-hub upload \
examples/meter-detection/data/aihub_calibration \
examples/meter-detection/data/inputs.npz \
--onnx-path build/qai-hub/meter-detection/model.aihub.onnx
```
If the meter-detection job is still the last training job in `.qc-cli.json`, `--from-job` can be omitted. Keeping it If the meter-detection job is still the last training job in `.qc-cli.json`, `--from-job` can be omitted. Keeping it
explicit prevents accidentally uploading an artifact from a different training run. explicit prevents accidentally uploading an artifact from a different training run.

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""
from __future__ import annotations
import argparse
from pathlib import Path
import onnx # type: ignore[reportMissingImports]
def sanitize_onnx(path: Path, output_path: Path | None = None) -> Path:
model = onnx.load(path)
io_names = {value.name for value in (*model.graph.input, *model.graph.output)}
retained_value_info = [value for value in model.graph.value_info if value.name not in io_names]
destination = output_path or path
if len(retained_value_info) != len(model.graph.value_info):
del model.graph.value_info[:]
model.graph.value_info.extend(retained_value_info)
destination.parent.mkdir(parents=True, exist_ok=True)
onnx.save(model, destination)
return destination
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("onnx_path", type=Path)
parser.add_argument("--output", type=Path)
args = parser.parse_args()
written = sanitize_onnx(args.onnx_path, args.output)
print(f"Saved sanitized ONNX model to {written}")
if __name__ == "__main__":
main()

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Any from typing import Any
import yaml import yaml
from sanitize_onnx import sanitize_onnx
from ultralytics import YOLO # type: ignore[reportMissingImports] from ultralytics import YOLO # type: ignore[reportMissingImports]
@@ -103,7 +104,8 @@ def main() -> None:
copy_if_exists(trained_weights, model_dir / "best.pt") copy_if_exists(trained_weights, model_dir / "best.pt")
trained_model = YOLO(str(trained_weights)) trained_model = YOLO(str(trained_weights))
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz)) onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
copy_if_exists(onnx_path, model_dir / "model.onnx") saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
print(f"Saved {saved_onnx_path}")
metrics = { metrics = {
"model": args.model, "model": args.model,
@@ -114,7 +116,7 @@ def main() -> None:
"patience": args.patience, "patience": args.patience,
"data_yaml": str(data_yaml), "data_yaml": str(data_yaml),
"weights": str(trained_weights), "weights": str(trained_weights),
"onnx": str(onnx_path), "onnx": str(saved_onnx_path),
} }
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8") (model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
print(f"Saved model artifacts to {model_dir}") print(f"Saved model artifacts to {model_dir}")