From 6bc25dc183d6b7da30b1b7910974443359214da1 Mon Sep 17 00:00:00 2001
From: slalom
Date: Thu, 4 Jun 2026 17:28:17 -0400
Subject: [PATCH] restructure config to use Device class directly

Also include device validation
---
Review notes (everything between the three-dash line and the first
"diff --git" line is commentary and is not part of the commit message):

 * src/config.py: `AIHubConfig.device` changes from `str` to
   `qai_hub.client.Device`. The `mode="before"` validator `parse_device`
   converts a plain string with `Device(value)` and passes every other
   value through unchanged.
   NOTE(review): both READMEs now document the mapping form
   (`device:` / `name: ...`); confirm pydantic builds a `Device` from
   that mapping, since this patch does not show it.

 * src/commands/ai_hub.py: `_validate_device` queries
   `hub.get_devices(name=..., os=..., attributes=...)` and raises
   `typer.Exit(1)` after printing a red message when the call raises or
   returns no matches. It is called from `_compile_step`,
   `_validate_step` and `_profile_step`.
   NOTE(review): `raise typer.Exit(1)` inside `except Exception as e` is
   not chained (`from e`); only `str(e)` survives in the printed message.
   NOTE(review): presumably a multi-step run repeats the device lookup
   once per step - confirm that is acceptable.

 * src/qualcomm/aihub_jobs.py: `device_name: str` becomes
   `device: Device` in `submit_compile_job`, `submit_inference_job` and
   `submit_profile_job`.
   NOTE(review): ai_hub.py has exactly one removed line in this patch
   (the `submit_compile_job` call). The `submit_inference_job` and
   `submit_profile_job` call sites are outside the visible hunks -
   confirm they do not still pass `device_name=`.

 README.md                  |  3 ++-
 examples/ai-hub/README.md  |  3 ++-
 src/commands/ai_hub.py     | 34 +++++++++++++++++++++++++++++++++-
 src/config.py              | 12 ++++++++++--
 src/qualcomm/aihub_jobs.py | 12 ++++++------
 5 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index d71e9f9..14faaac 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,8 @@ sagemaker:
   hyperparameters: {}
 
 aihub:
-  device: Samsung Galaxy S25 (Family)
+  device:
+    name: Samsung Galaxy S25 (Family)
   target_runtime: tflite
   input_specs: {} # Required before running qc-cli ai-hub commands
   job_name: null # Optional prefix for AI Hub Workbench jobs
diff --git a/examples/ai-hub/README.md b/examples/ai-hub/README.md
index 947598f..dce4bd1 100644
--- a/examples/ai-hub/README.md
+++ b/examples/ai-hub/README.md
@@ -28,7 +28,8 @@ Your `config.yaml` must include AI Hub settings:
 
 ```yaml
 aihub:
-  device: Samsung Galaxy S25 (Family)
+  device:
+    name: Samsung Galaxy S25 (Family)
   target_runtime: tflite
   input_specs:
     input: [[1, 3, 160, 160], float32]
diff --git a/src/commands/ai_hub.py b/src/commands/ai_hub.py
index 3ef3335..d42b31e 100644
--- a/src/commands/ai_hub.py
+++ b/src/commands/ai_hub.py
@@ -4,7 +4,9 @@ from enum import StrEnum
 from pathlib import Path
 from typing import Any
 
+import qai_hub.hub as hub
 import typer
+from qai_hub.client import Device
 
 from src import state as state_ops
 from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
@@ -99,6 +101,33 @@ def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: boo
     return resolved
 
 
+def _device_selector(device: Device) -> str:
+    parts: list[str] = []
+    if device.name:
+        parts.append(f"name={device.name!r}")
+    if device.os:
+        parts.append(f"os={device.os!r}")
+    if device.attributes:
+        parts.append(f"attributes={device.attributes!r}")
+    return ", ".join(parts) if parts else "empty selector"
+
+
+def _validate_device(cfg: Config) -> None:
+    device = cfg.aihub.device
+    try:
+        matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
+    except Exception as e:
+        CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
+        raise typer.Exit(1)
+
+    if matches:
+        return
+
+    CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
+    CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
+    raise typer.Exit(1)
+
+
 def _quantize_step(
     cfg: Config,
     config_path: str,
@@ -156,6 +185,7 @@ def _compile_step(
     prefer_quantized: bool,
 ) -> str:
     st = state_ops.store(config_path)
+    _validate_device(cfg)
     specs = _input_specs(cfg)
 
     model: Any
@@ -184,7 +214,7 @@ def _compile_step(
     try:
         result = aihub_jobs.submit_compile_job(
             model=model,
-            device_name=cfg.aihub.device,
+            device=cfg.aihub.device,
             input_specs=specs,
             target_runtime=cfg.aihub.target_runtime,
             options=cfg.aihub.compile_options,
@@ -214,6 +244,7 @@ def _validate_step(
     model_id: str | None,
     input_name: str | None,
 ) -> str:
+    _validate_device(cfg)
     specs = _input_specs(cfg)
     resolved_model_id = _model_id_or_state(config_path, model_id)
     try:
@@ -247,6 +278,7 @@ def _validate_step(
 
 
 def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
+    _validate_device(cfg)
     resolved_model_id = _model_id_or_state(config_path, model_id)
     try:
         result = aihub_jobs.submit_profile_job(
diff --git a/src/config.py b/src/config.py
index 2555a06..3da8e81 100644
--- a/src/config.py
+++ b/src/config.py
@@ -4,7 +4,8 @@ from typing import Any, Literal, TypedDict
 
 from mypy_boto3_s3.literals import BucketLocationConstraintType
 from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
+from qai_hub.client import Device
 
 
 class MlflowMode(StrEnum):
@@ -81,7 +82,7 @@ class SageMakerConfig(BaseModel):
 
 
 class AIHubConfig(BaseModel):
-    device: str = "Samsung Galaxy S25 (Family)"
+    device: Device = Field(default_factory=lambda: Device("Samsung Galaxy S25 (Family)"))
     target_runtime: str = "tflite"
     input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
     job_name: str | None = None
@@ -91,6 +92,13 @@ class AIHubConfig(BaseModel):
     quantize_options: str | None = None
     output_dir: str = "build/qai-hub"
 
+    @field_validator("device", mode="before")
+    @classmethod
+    def parse_device(cls, value: Any) -> Any:
+        if isinstance(value, str):
+            return Device(value)
+        return value
+
 
 class MlflowConfig(BaseModel):
     mode: MlflowMode = MlflowMode.disabled
diff --git a/src/qualcomm/aihub_jobs.py b/src/qualcomm/aihub_jobs.py
index 7641d15..6afda49 100644
--- a/src/qualcomm/aihub_jobs.py
+++ b/src/qualcomm/aihub_jobs.py
@@ -29,7 +29,7 @@ def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
 
 def submit_compile_job(
     model: Any,
-    device_name: str,
+    device: Device,
     input_specs: dict[str, tuple[tuple[int, ...], str]],
     target_runtime: str,
     options: str | None = None,
@@ -52,7 +52,7 @@ def submit_compile_job(
 
     job = hub.submit_compile_job(
         model=model_arg,
-        device=Device(device_name),
+        device=device,
         name=job_name,
         input_specs=input_specs,
         options=compile_options,
@@ -65,14 +65,14 @@ def submit_compile_job(
 
 def submit_inference_job(
     model_id: str,
-    device_name: str,
+    device: Device,
     inputs: dict[str, Any],
     output_dir: str | Path,
     job_name: str | None = None,
 ) -> InferenceJobResult:
     job = hub.submit_inference_job(
         model=hub.get_model(model_id),
-        device=Device(device_name),
+        device=device,
         inputs=_dataset_entries(inputs),
         name=job_name,
     )
@@ -84,13 +84,13 @@ def submit_inference_job(
 
 def submit_profile_job(
     model_id: str,
-    device_name: str,
+    device: Device,
     options: str | None = None,
     job_name: str | None = None,
 ) -> ProfileJobResult:
     job = hub.submit_profile_job(
         model=hub.get_model(model_id),
-        device=Device(device_name),
+        device=device,
         name=job_name,
         options=options or "",
     )