restructure config to use Device class directly
Also include device validation
This commit is contained in:
@@ -67,7 +67,8 @@ sagemaker:
|
||||
hyperparameters: {}
|
||||
|
||||
aihub:
|
||||
device: Samsung Galaxy S25 (Family)
|
||||
device:
|
||||
name: Samsung Galaxy S25 (Family)
|
||||
target_runtime: tflite
|
||||
input_specs: {} # Required before running qc-cli ai-hub commands
|
||||
job_name: null # Optional prefix for AI Hub Workbench jobs
|
||||
|
||||
@@ -28,7 +28,8 @@ Your `config.yaml` must include AI Hub settings:
|
||||
|
||||
```yaml
|
||||
aihub:
|
||||
device: Samsung Galaxy S25 (Family)
|
||||
device:
|
||||
name: Samsung Galaxy S25 (Family)
|
||||
target_runtime: tflite
|
||||
input_specs:
|
||||
input: [[1, 3, 160, 160], float32]
|
||||
|
||||
@@ -4,7 +4,9 @@ from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import qai_hub.hub as hub
|
||||
import typer
|
||||
from qai_hub.client import Device
|
||||
|
||||
from src import state as state_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
@@ -99,6 +101,33 @@ def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: boo
|
||||
return resolved
|
||||
|
||||
|
||||
def _device_selector(device: Device) -> str:
|
||||
parts: list[str] = []
|
||||
if device.name:
|
||||
parts.append(f"name={device.name!r}")
|
||||
if device.os:
|
||||
parts.append(f"os={device.os!r}")
|
||||
if device.attributes:
|
||||
parts.append(f"attributes={device.attributes!r}")
|
||||
return ", ".join(parts) if parts else "empty selector"
|
||||
|
||||
|
||||
def _validate_device(cfg: Config) -> None:
|
||||
device = cfg.aihub.device
|
||||
try:
|
||||
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if matches:
|
||||
return
|
||||
|
||||
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
|
||||
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _quantize_step(
|
||||
cfg: Config,
|
||||
config_path: str,
|
||||
@@ -156,6 +185,7 @@ def _compile_step(
|
||||
prefer_quantized: bool,
|
||||
) -> str:
|
||||
st = state_ops.store(config_path)
|
||||
_validate_device(cfg)
|
||||
specs = _input_specs(cfg)
|
||||
|
||||
model: Any
|
||||
@@ -184,7 +214,7 @@ def _compile_step(
|
||||
try:
|
||||
result = aihub_jobs.submit_compile_job(
|
||||
model=model,
|
||||
device_name=cfg.aihub.device,
|
||||
device=cfg.aihub.device,
|
||||
input_specs=specs,
|
||||
target_runtime=cfg.aihub.target_runtime,
|
||||
options=cfg.aihub.compile_options,
|
||||
@@ -214,6 +244,7 @@ def _validate_step(
|
||||
model_id: str | None,
|
||||
input_name: str | None,
|
||||
) -> str:
|
||||
_validate_device(cfg)
|
||||
specs = _input_specs(cfg)
|
||||
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||
try:
|
||||
@@ -247,6 +278,7 @@ def _validate_step(
|
||||
|
||||
|
||||
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
|
||||
_validate_device(cfg)
|
||||
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||
try:
|
||||
result = aihub_jobs.submit_profile_job(
|
||||
|
||||
@@ -4,7 +4,8 @@ from typing import Any, Literal, TypedDict
|
||||
|
||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from qai_hub.client import Device
|
||||
|
||||
|
||||
class MlflowMode(StrEnum):
|
||||
@@ -81,7 +82,7 @@ class SageMakerConfig(BaseModel):
|
||||
|
||||
|
||||
class AIHubConfig(BaseModel):
|
||||
device: str = "Samsung Galaxy S25 (Family)"
|
||||
device: Device = Field(default_factory=lambda: Device("Samsung Galaxy S25 (Family)"))
|
||||
target_runtime: str = "tflite"
|
||||
input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
|
||||
job_name: str | None = None
|
||||
@@ -91,6 +92,13 @@ class AIHubConfig(BaseModel):
|
||||
quantize_options: str | None = None
|
||||
output_dir: str = "build/qai-hub"
|
||||
|
||||
@field_validator("device", mode="before")
|
||||
@classmethod
|
||||
def parse_device(cls, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return Device(value)
|
||||
return value
|
||||
|
||||
|
||||
class MlflowConfig(BaseModel):
|
||||
mode: MlflowMode = MlflowMode.disabled
|
||||
|
||||
@@ -29,7 +29,7 @@ def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
|
||||
|
||||
def submit_compile_job(
|
||||
model: Any,
|
||||
device_name: str,
|
||||
device: Device,
|
||||
input_specs: dict[str, tuple[tuple[int, ...], str]],
|
||||
target_runtime: str,
|
||||
options: str | None = None,
|
||||
@@ -52,7 +52,7 @@ def submit_compile_job(
|
||||
|
||||
job = hub.submit_compile_job(
|
||||
model=model_arg,
|
||||
device=Device(device_name),
|
||||
device=device,
|
||||
name=job_name,
|
||||
input_specs=input_specs,
|
||||
options=compile_options,
|
||||
@@ -65,14 +65,14 @@ def submit_compile_job(
|
||||
|
||||
def submit_inference_job(
|
||||
model_id: str,
|
||||
device_name: str,
|
||||
device: Device,
|
||||
inputs: dict[str, Any],
|
||||
output_dir: str | Path,
|
||||
job_name: str | None = None,
|
||||
) -> InferenceJobResult:
|
||||
job = hub.submit_inference_job(
|
||||
model=hub.get_model(model_id),
|
||||
device=Device(device_name),
|
||||
device=device,
|
||||
inputs=_dataset_entries(inputs),
|
||||
name=job_name,
|
||||
)
|
||||
@@ -84,13 +84,13 @@ def submit_inference_job(
|
||||
|
||||
def submit_profile_job(
|
||||
model_id: str,
|
||||
device_name: str,
|
||||
device: Device,
|
||||
options: str | None = None,
|
||||
job_name: str | None = None,
|
||||
) -> ProfileJobResult:
|
||||
job = hub.submit_profile_job(
|
||||
model=hub.get_model(model_id),
|
||||
device=Device(device_name),
|
||||
device=device,
|
||||
name=job_name,
|
||||
options=options or "",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user