restructure config to use Device class directly

Also include device validation
This commit is contained in:
2026-06-04 17:28:17 -04:00
parent 71a95aa3a7
commit 6bc25dc183
5 changed files with 53 additions and 11 deletions

View File

@@ -4,7 +4,9 @@ from enum import StrEnum
from pathlib import Path
from typing import Any
import qai_hub.hub as hub
import typer
from qai_hub.client import Device
from src import state as state_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
@@ -99,6 +101,33 @@ def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: boo
return resolved
def _device_selector(device: Device) -> str:
parts: list[str] = []
if device.name:
parts.append(f"name={device.name!r}")
if device.os:
parts.append(f"os={device.os!r}")
if device.attributes:
parts.append(f"attributes={device.attributes!r}")
return ", ".join(parts) if parts else "empty selector"
def _validate_device(cfg: Config) -> None:
device = cfg.aihub.device
try:
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
except Exception as e:
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
raise typer.Exit(1)
if matches:
return
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
raise typer.Exit(1)
def _quantize_step(
cfg: Config,
config_path: str,
@@ -156,6 +185,7 @@ def _compile_step(
prefer_quantized: bool,
) -> str:
st = state_ops.store(config_path)
_validate_device(cfg)
specs = _input_specs(cfg)
model: Any
@@ -184,7 +214,7 @@ def _compile_step(
try:
result = aihub_jobs.submit_compile_job(
model=model,
device_name=cfg.aihub.device,
device=cfg.aihub.device,
input_specs=specs,
target_runtime=cfg.aihub.target_runtime,
options=cfg.aihub.compile_options,
@@ -214,6 +244,7 @@ def _validate_step(
model_id: str | None,
input_name: str | None,
) -> str:
_validate_device(cfg)
specs = _input_specs(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id)
try:
@@ -247,6 +278,7 @@ def _validate_step(
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
_validate_device(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id)
try:
result = aihub_jobs.submit_profile_job(

View File

@@ -4,7 +4,8 @@ from typing import Any, Literal, TypedDict
from mypy_boto3_s3.literals import BucketLocationConstraintType
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from qai_hub.client import Device
class MlflowMode(StrEnum):
@@ -81,7 +82,7 @@ class SageMakerConfig(BaseModel):
class AIHubConfig(BaseModel):
device: str = "Samsung Galaxy S25 (Family)"
device: Device = Field(default_factory=lambda: Device("Samsung Galaxy S25 (Family)"))
target_runtime: str = "tflite"
input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
job_name: str | None = None
@@ -91,6 +92,13 @@ class AIHubConfig(BaseModel):
quantize_options: str | None = None
output_dir: str = "build/qai-hub"
@field_validator("device", mode="before")
@classmethod
def parse_device(cls, value: Any) -> Any:
if isinstance(value, str):
return Device(value)
return value
class MlflowConfig(BaseModel):
mode: MlflowMode = MlflowMode.disabled

View File

@@ -29,7 +29,7 @@ def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
def submit_compile_job(
model: Any,
device_name: str,
device: Device,
input_specs: dict[str, tuple[tuple[int, ...], str]],
target_runtime: str,
options: str | None = None,
@@ -52,7 +52,7 @@ def submit_compile_job(
job = hub.submit_compile_job(
model=model_arg,
device=Device(device_name),
device=device,
name=job_name,
input_specs=input_specs,
options=compile_options,
@@ -65,14 +65,14 @@ def submit_compile_job(
def submit_inference_job(
model_id: str,
device_name: str,
device: Device,
inputs: dict[str, Any],
output_dir: str | Path,
job_name: str | None = None,
) -> InferenceJobResult:
job = hub.submit_inference_job(
model=hub.get_model(model_id),
device=Device(device_name),
device=device,
inputs=_dataset_entries(inputs),
name=job_name,
)
@@ -84,13 +84,13 @@ def submit_inference_job(
def submit_profile_job(
model_id: str,
device_name: str,
device: Device,
options: str | None = None,
job_name: str | None = None,
) -> ProfileJobResult:
job = hub.submit_profile_job(
model=hub.get_model(model_id),
device=Device(device_name),
device=device,
name=job_name,
options=options or "",
)