#!/usr/bin/env python3
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""

from __future__ import annotations

import argparse
from pathlib import Path

import onnx  # type: ignore[reportMissingImports]


def sanitize_onnx(path: Path, output_path: Path | None = None) -> Path:
    """Drop ``graph.value_info`` entries whose names match a graph input or output.

    Args:
        path: ONNX model file to read.
        output_path: Where to write the sanitized model. Defaults to ``path``,
            i.e. the model is rewritten in place.

    Returns:
        The path of the sanitized model. When ``output_path`` differs from
        ``path`` the model is always written there, even if nothing was removed,
        so the returned path is guaranteed to exist.
    """
    source = Path(path)
    destination = Path(output_path) if output_path is not None else source

    # NOTE(review): onnx.load pulls any external tensor data into memory and
    # onnx.save writes it back inline by default; confirm this is acceptable
    # for very large (>2 GB) models.
    model = onnx.load(source)
    graph = model.graph

    io_names = {value.name for value in (*graph.input, *graph.output)}
    retained_value_info = [
        value for value in graph.value_info if value.name not in io_names
    ]
    changed = len(retained_value_info) != len(graph.value_info)

    if changed:
        # Protobuf repeated fields cannot be reassigned; clear and refill instead.
        del graph.value_info[:]
        graph.value_info.extend(retained_value_info)

    # Rewrite in place only when something changed, but always materialize a
    # distinct output path so callers never receive a path to a missing file.
    writes_elsewhere = destination.resolve() != source.resolve()
    if changed or writes_elsewhere:
        destination.parent.mkdir(parents=True, exist_ok=True)
        onnx.save(model, destination)
    return destination


def main() -> None:
    """Parse command-line arguments and sanitize the given ONNX model."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("onnx_path", type=Path)
    parser.add_argument("--output", type=Path)
    args = parser.parse_args()
    written = sanitize_onnx(args.onnx_path, args.output)
    print(f"Saved sanitized ONNX model to {written}")


if __name__ == "__main__":
    main()