simplify jobs script

This commit is contained in:
2026-06-01 16:54:06 -04:00
parent 090be14a6a
commit b411be7904

View File

@@ -1,32 +1,26 @@
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, TypedDict
import qai_hub.hub as hub
from qai_hub.client import CompileJob, Device, InferenceJob, Model, ProfileJob, QuantizeDtype, QuantizeJob
def _hub() -> Any: class ModelJobResult(TypedDict):
import qai_hub as hub job: CompileJob | QuantizeJob
job_id: str
return hub model: Model
model_id: str
def _id(obj: Any) -> str: class InferenceJobResult(TypedDict):
for attr in ("model_id", "job_id", "id"): job: InferenceJob
value = getattr(obj, attr, None) job_id: str
if value: outputs: Any
return str(value)
return str(obj)
def _target_model(job: Any) -> Any: class ProfileJobResult(TypedDict):
if hasattr(job, "get_target_model"): job: ProfileJob
return job.get_target_model() job_id: str
model = getattr(job, "target_model", None)
if model is not None:
return model
return job
def get_model(model_id: str) -> Any:
return _hub().get_model(model_id)
def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]: def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
@@ -41,8 +35,7 @@ def submit_compile_job(
options: str | None = None, options: str | None = None,
job_name: str | None = None, job_name: str | None = None,
model_name: str | None = None, model_name: str | None = None,
) -> dict[str, Any]: ) -> ModelJobResult:
hub = _hub()
compile_options = f"--target_runtime {target_runtime}" compile_options = f"--target_runtime {target_runtime}"
if options: if options:
compile_options = f"{compile_options} {options}" compile_options = f"{compile_options} {options}"
@@ -52,22 +45,22 @@ def submit_compile_job(
model_arg = str(model) model_arg = str(model)
elif isinstance(model, str): elif isinstance(model, str):
candidate = Path(model) candidate = Path(model)
model_arg = model if candidate.exists() or candidate.suffix else get_model(model) model_arg = model if candidate.exists() or candidate.suffix else hub.get_model(model)
if model_name and isinstance(model_arg, str) and Path(model_arg).exists(): if model_name and isinstance(model_arg, str) and Path(model_arg).exists():
model_arg = hub.upload_model(model_arg, name=model_name) model_arg = hub.upload_model(model_arg, name=model_name)
job = hub.submit_compile_job( job = hub.submit_compile_job(
model=model_arg, model=model_arg,
device=hub.Device(device_name), device=Device(device_name),
name=job_name, name=job_name,
input_specs=input_specs, input_specs=input_specs,
options=compile_options, options=compile_options,
) )
target_model = _target_model(job) target_model = job.get_target_model()
if target_model is None: if target_model is None:
raise RuntimeError(f"Compile job {_id(job)} did not produce a target model.") raise RuntimeError(f"Compile job {job.job_id} did not produce a target model.")
return {"job": job, "job_id": _id(job), "model": target_model, "model_id": _id(target_model)} return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
def submit_inference_job( def submit_inference_job(
@@ -76,18 +69,17 @@ def submit_inference_job(
inputs: dict[str, Any], inputs: dict[str, Any],
output_dir: str | Path, output_dir: str | Path,
job_name: str | None = None, job_name: str | None = None,
) -> dict[str, Any]: ) -> InferenceJobResult:
hub = _hub()
job = hub.submit_inference_job( job = hub.submit_inference_job(
model=get_model(model_id), model=hub.get_model(model_id),
device=hub.Device(device_name), device=Device(device_name),
inputs=_dataset_entries(inputs), inputs=_dataset_entries(inputs),
name=job_name, name=job_name,
) )
out = Path(output_dir) out = Path(output_dir)
out.mkdir(parents=True, exist_ok=True) out.mkdir(parents=True, exist_ok=True)
data = job.download_output_data(str(out)) data = job.download_output_data(str(out))
return {"job": job, "job_id": _id(job), "outputs": data} return {"job": job, "job_id": str(job.job_id), "outputs": data}
def submit_profile_job( def submit_profile_job(
@@ -95,15 +87,14 @@ def submit_profile_job(
device_name: str, device_name: str,
options: str | None = None, options: str | None = None,
job_name: str | None = None, job_name: str | None = None,
) -> dict[str, Any]: ) -> ProfileJobResult:
hub = _hub()
job = hub.submit_profile_job( job = hub.submit_profile_job(
model=get_model(model_id), model=hub.get_model(model_id),
device=hub.Device(device_name), device=Device(device_name),
name=job_name, name=job_name,
options=options or "", options=options or "",
) )
return {"job": job, "job_id": _id(job)} return {"job": job, "job_id": str(job.job_id)}
def submit_quantize_job( def submit_quantize_job(
@@ -112,33 +103,27 @@ def submit_quantize_job(
options: str | None = None, options: str | None = None,
job_name: str | None = None, job_name: str | None = None,
model_name: str | None = None, model_name: str | None = None,
) -> dict[str, Any]: ) -> ModelJobResult:
hub = _hub()
model_arg = str(model) model_arg = str(model)
if model_name and Path(model_arg).exists(): if model_name and Path(model_arg).exists():
model_arg = hub.upload_model(model_arg, name=model_name) model_arg = hub.upload_model(model_arg, name=model_name)
job = hub.submit_quantize_job( job = hub.submit_quantize_job(
model=model_arg, model=model_arg,
calibration_data=_dataset_entries(calibration_data), calibration_data=_dataset_entries(calibration_data),
weights_dtype=hub.QuantizeDtype.INT8, weights_dtype=QuantizeDtype.INT8,
activations_dtype=hub.QuantizeDtype.INT8, activations_dtype=QuantizeDtype.INT8,
name=job_name, name=job_name,
options=options or "", options=options or "",
) )
target_model = _target_model(job) target_model = job.get_target_model()
if target_model is None: if target_model is None:
raise RuntimeError(f"Quantize job {_id(job)} did not produce a target model.") raise RuntimeError(f"Quantize job {job.job_id} did not produce a target model.")
return {"job": job, "job_id": _id(job), "model": target_model, "model_id": _id(target_model)} return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
def download_model(model_id: str, output_path: str | Path) -> str: def download_model(model_id: str, output_path: str | Path) -> str:
dest = Path(output_path) dest = Path(output_path)
dest.parent.mkdir(parents=True, exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
model = get_model(model_id) model = hub.get_model(model_id)
if hasattr(model, "download"): result = model.download(str(dest))
result = model.download(str(dest)) return str(result or dest)
return str(result or dest)
if hasattr(model, "download_model"):
result = model.download_model(str(dest))
return str(result or dest)
raise RuntimeError("AI Hub model object does not expose a download method.")