restructure config to use Device class directly

Also include device validation
This commit is contained in:
2026-06-04 17:28:17 -04:00
parent 71a95aa3a7
commit 6bc25dc183
5 changed files with 53 additions and 11 deletions

View File

@@ -4,7 +4,9 @@ from enum import StrEnum
from pathlib import Path
from typing import Any
import qai_hub.hub as hub
import typer
from qai_hub.client import Device
from src import state as state_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
@@ -99,6 +101,33 @@ def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: boo
return resolved
def _device_selector(device: Device) -> str:
parts: list[str] = []
if device.name:
parts.append(f"name={device.name!r}")
if device.os:
parts.append(f"os={device.os!r}")
if device.attributes:
parts.append(f"attributes={device.attributes!r}")
return ", ".join(parts) if parts else "empty selector"
def _validate_device(cfg: Config) -> None:
device = cfg.aihub.device
try:
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
except Exception as e:
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
raise typer.Exit(1)
if matches:
return
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
raise typer.Exit(1)
def _quantize_step(
cfg: Config,
config_path: str,
@@ -156,6 +185,7 @@ def _compile_step(
prefer_quantized: bool,
) -> str:
st = state_ops.store(config_path)
_validate_device(cfg)
specs = _input_specs(cfg)
model: Any
@@ -184,7 +214,7 @@ def _compile_step(
try:
result = aihub_jobs.submit_compile_job(
model=model,
device_name=cfg.aihub.device,
device=cfg.aihub.device,
input_specs=specs,
target_runtime=cfg.aihub.target_runtime,
options=cfg.aihub.compile_options,
@@ -214,6 +244,7 @@ def _validate_step(
model_id: str | None,
input_name: str | None,
) -> str:
_validate_device(cfg)
specs = _input_specs(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id)
try:
@@ -247,6 +278,7 @@ def _validate_step(
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
_validate_device(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id)
try:
result = aihub_jobs.submit_profile_job(