simplify jobs script
This commit is contained in:
@@ -1,32 +1,26 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import qai_hub.hub as hub
|
||||
from qai_hub.client import CompileJob, Device, InferenceJob, Model, ProfileJob, QuantizeDtype, QuantizeJob
|
||||
|
||||
|
||||
def _hub() -> Any:
|
||||
import qai_hub as hub
|
||||
|
||||
return hub
|
||||
class ModelJobResult(TypedDict):
|
||||
job: CompileJob | QuantizeJob
|
||||
job_id: str
|
||||
model: Model
|
||||
model_id: str
|
||||
|
||||
|
||||
def _id(obj: Any) -> str:
|
||||
for attr in ("model_id", "job_id", "id"):
|
||||
value = getattr(obj, attr, None)
|
||||
if value:
|
||||
return str(value)
|
||||
return str(obj)
|
||||
class InferenceJobResult(TypedDict):
|
||||
job: InferenceJob
|
||||
job_id: str
|
||||
outputs: Any
|
||||
|
||||
|
||||
def _target_model(job: Any) -> Any:
|
||||
if hasattr(job, "get_target_model"):
|
||||
return job.get_target_model()
|
||||
model = getattr(job, "target_model", None)
|
||||
if model is not None:
|
||||
return model
|
||||
return job
|
||||
|
||||
|
||||
def get_model(model_id: str) -> Any:
|
||||
return _hub().get_model(model_id)
|
||||
class ProfileJobResult(TypedDict):
|
||||
job: ProfileJob
|
||||
job_id: str
|
||||
|
||||
|
||||
def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
|
||||
@@ -41,8 +35,7 @@ def submit_compile_job(
|
||||
options: str | None = None,
|
||||
job_name: str | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
hub = _hub()
|
||||
) -> ModelJobResult:
|
||||
compile_options = f"--target_runtime {target_runtime}"
|
||||
if options:
|
||||
compile_options = f"{compile_options} {options}"
|
||||
@@ -52,22 +45,22 @@ def submit_compile_job(
|
||||
model_arg = str(model)
|
||||
elif isinstance(model, str):
|
||||
candidate = Path(model)
|
||||
model_arg = model if candidate.exists() or candidate.suffix else get_model(model)
|
||||
model_arg = model if candidate.exists() or candidate.suffix else hub.get_model(model)
|
||||
|
||||
if model_name and isinstance(model_arg, str) and Path(model_arg).exists():
|
||||
model_arg = hub.upload_model(model_arg, name=model_name)
|
||||
|
||||
job = hub.submit_compile_job(
|
||||
model=model_arg,
|
||||
device=hub.Device(device_name),
|
||||
device=Device(device_name),
|
||||
name=job_name,
|
||||
input_specs=input_specs,
|
||||
options=compile_options,
|
||||
)
|
||||
target_model = _target_model(job)
|
||||
target_model = job.get_target_model()
|
||||
if target_model is None:
|
||||
raise RuntimeError(f"Compile job {_id(job)} did not produce a target model.")
|
||||
return {"job": job, "job_id": _id(job), "model": target_model, "model_id": _id(target_model)}
|
||||
raise RuntimeError(f"Compile job {job.job_id} did not produce a target model.")
|
||||
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
|
||||
|
||||
|
||||
def submit_inference_job(
|
||||
@@ -76,18 +69,17 @@ def submit_inference_job(
|
||||
inputs: dict[str, Any],
|
||||
output_dir: str | Path,
|
||||
job_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
hub = _hub()
|
||||
) -> InferenceJobResult:
|
||||
job = hub.submit_inference_job(
|
||||
model=get_model(model_id),
|
||||
device=hub.Device(device_name),
|
||||
model=hub.get_model(model_id),
|
||||
device=Device(device_name),
|
||||
inputs=_dataset_entries(inputs),
|
||||
name=job_name,
|
||||
)
|
||||
out = Path(output_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
data = job.download_output_data(str(out))
|
||||
return {"job": job, "job_id": _id(job), "outputs": data}
|
||||
return {"job": job, "job_id": str(job.job_id), "outputs": data}
|
||||
|
||||
|
||||
def submit_profile_job(
|
||||
@@ -95,15 +87,14 @@ def submit_profile_job(
|
||||
device_name: str,
|
||||
options: str | None = None,
|
||||
job_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
hub = _hub()
|
||||
) -> ProfileJobResult:
|
||||
job = hub.submit_profile_job(
|
||||
model=get_model(model_id),
|
||||
device=hub.Device(device_name),
|
||||
model=hub.get_model(model_id),
|
||||
device=Device(device_name),
|
||||
name=job_name,
|
||||
options=options or "",
|
||||
)
|
||||
return {"job": job, "job_id": _id(job)}
|
||||
return {"job": job, "job_id": str(job.job_id)}
|
||||
|
||||
|
||||
def submit_quantize_job(
|
||||
@@ -112,33 +103,27 @@ def submit_quantize_job(
|
||||
options: str | None = None,
|
||||
job_name: str | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
hub = _hub()
|
||||
) -> ModelJobResult:
|
||||
model_arg = str(model)
|
||||
if model_name and Path(model_arg).exists():
|
||||
model_arg = hub.upload_model(model_arg, name=model_name)
|
||||
job = hub.submit_quantize_job(
|
||||
model=model_arg,
|
||||
calibration_data=_dataset_entries(calibration_data),
|
||||
weights_dtype=hub.QuantizeDtype.INT8,
|
||||
activations_dtype=hub.QuantizeDtype.INT8,
|
||||
weights_dtype=QuantizeDtype.INT8,
|
||||
activations_dtype=QuantizeDtype.INT8,
|
||||
name=job_name,
|
||||
options=options or "",
|
||||
)
|
||||
target_model = _target_model(job)
|
||||
target_model = job.get_target_model()
|
||||
if target_model is None:
|
||||
raise RuntimeError(f"Quantize job {_id(job)} did not produce a target model.")
|
||||
return {"job": job, "job_id": _id(job), "model": target_model, "model_id": _id(target_model)}
|
||||
raise RuntimeError(f"Quantize job {job.job_id} did not produce a target model.")
|
||||
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
|
||||
|
||||
|
||||
def download_model(model_id: str, output_path: str | Path) -> str:
|
||||
dest = Path(output_path)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
model = get_model(model_id)
|
||||
if hasattr(model, "download"):
|
||||
model = hub.get_model(model_id)
|
||||
result = model.download(str(dest))
|
||||
return str(result or dest)
|
||||
if hasattr(model, "download_model"):
|
||||
result = model.download_model(str(dest))
|
||||
return str(result or dest)
|
||||
raise RuntimeError("AI Hub model object does not expose a download method.")
|
||||
|
||||
Reference in New Issue
Block a user