restructure config to use Device class directly
Also include device validation
This commit is contained in:
@@ -29,7 +29,7 @@ def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
|
||||
|
||||
def submit_compile_job(
|
||||
model: Any,
|
||||
device_name: str,
|
||||
device: Device,
|
||||
input_specs: dict[str, tuple[tuple[int, ...], str]],
|
||||
target_runtime: str,
|
||||
options: str | None = None,
|
||||
@@ -52,7 +52,7 @@ def submit_compile_job(
|
||||
|
||||
job = hub.submit_compile_job(
|
||||
model=model_arg,
|
||||
device=Device(device_name),
|
||||
device=device,
|
||||
name=job_name,
|
||||
input_specs=input_specs,
|
||||
options=compile_options,
|
||||
@@ -65,14 +65,14 @@ def submit_compile_job(
|
||||
|
||||
def submit_inference_job(
|
||||
model_id: str,
|
||||
device_name: str,
|
||||
device: Device,
|
||||
inputs: dict[str, Any],
|
||||
output_dir: str | Path,
|
||||
job_name: str | None = None,
|
||||
) -> InferenceJobResult:
|
||||
job = hub.submit_inference_job(
|
||||
model=hub.get_model(model_id),
|
||||
device=Device(device_name),
|
||||
device=device,
|
||||
inputs=_dataset_entries(inputs),
|
||||
name=job_name,
|
||||
)
|
||||
@@ -84,13 +84,13 @@ def submit_inference_job(
|
||||
|
||||
def submit_profile_job(
|
||||
model_id: str,
|
||||
device_name: str,
|
||||
device: Device,
|
||||
options: str | None = None,
|
||||
job_name: str | None = None,
|
||||
) -> ProfileJobResult:
|
||||
job = hub.submit_profile_job(
|
||||
model=hub.get_model(model_id),
|
||||
device=Device(device_name),
|
||||
device=device,
|
||||
name=job_name,
|
||||
options=options or "",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user