39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
#!/usr/bin/env python3
|
|
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
import onnx # type: ignore[reportMissingImports]
|
|
|
|
|
|
def sanitize_onnx(path: str | Path, output_path: str | Path | None = None) -> Path:
    """Drop ``value_info`` entries whose names collide with graph inputs/outputs.

    Loads the model at *path* and removes every top-level ``graph.value_info``
    entry whose name matches a graph input or output. It then writes the
    result to *output_path*, or back over *path* when no output path is given.
    The model is written even if nothing was removed.

    Args:
        path: Location of the source ONNX model (``str`` or ``Path``).
        output_path: Where to write the sanitized model. ``None`` means
            overwrite *path* in place. Missing parent directories are created.

    Returns:
        The path the sanitized model was written to, as a ``Path``.
    """
    source = Path(path)
    # Explicit ``is not None`` rather than ``or``: only an omitted argument
    # should fall back to overwriting the input file.
    destination = Path(output_path) if output_path is not None else source

    # NOTE(review): older onnx releases only treated ``str`` arguments as
    # filesystem paths when resolving external tensor data, so pass strings.
    model = onnx.load(str(source))
    graph = model.graph

    io_names = {value.name for value in (*graph.input, *graph.output)}
    retained_value_info = [value for value in graph.value_info if value.name not in io_names]

    # Only rebuild the repeated field when at least one entry was dropped.
    if len(retained_value_info) != len(graph.value_info):
        del graph.value_info[:]
        graph.value_info.extend(retained_value_info)

    destination.parent.mkdir(parents=True, exist_ok=True)
    onnx.save(model, str(destination))
    return destination
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: sanitize one ONNX file and report where it went.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps normal command-line behavior.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("onnx_path", type=Path, help="ONNX model to sanitize")
    parser.add_argument(
        "--output",
        type=Path,
        help="where to write the sanitized model (default: overwrite onnx_path)",
    )
    args = parser.parse_args(argv)

    # Report a missing input as a usage error (exit status 2) instead of
    # letting the model loader fail with a traceback.
    if not args.onnx_path.is_file():
        parser.error(f"input model not found: {args.onnx_path}")

    written = sanitize_onnx(args.onnx_path, args.output)
    print(f"Saved sanitized ONNX model to {written}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main() returns None, so this exits with status 0.
    raise SystemExit(main())
|