restructure config to use Device class directly
Also include device validation
This commit is contained in:
@@ -4,7 +4,9 @@ from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import qai_hub.hub as hub
|
||||
import typer
|
||||
from qai_hub.client import Device
|
||||
|
||||
from src import state as state_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
@@ -99,6 +101,33 @@ def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: boo
|
||||
return resolved
|
||||
|
||||
|
||||
def _device_selector(device: Device) -> str:
|
||||
parts: list[str] = []
|
||||
if device.name:
|
||||
parts.append(f"name={device.name!r}")
|
||||
if device.os:
|
||||
parts.append(f"os={device.os!r}")
|
||||
if device.attributes:
|
||||
parts.append(f"attributes={device.attributes!r}")
|
||||
return ", ".join(parts) if parts else "empty selector"
|
||||
|
||||
|
||||
def _validate_device(cfg: Config) -> None:
|
||||
device = cfg.aihub.device
|
||||
try:
|
||||
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if matches:
|
||||
return
|
||||
|
||||
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
|
||||
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _quantize_step(
|
||||
cfg: Config,
|
||||
config_path: str,
|
||||
@@ -156,6 +185,7 @@ def _compile_step(
|
||||
prefer_quantized: bool,
|
||||
) -> str:
|
||||
st = state_ops.store(config_path)
|
||||
_validate_device(cfg)
|
||||
specs = _input_specs(cfg)
|
||||
|
||||
model: Any
|
||||
@@ -184,7 +214,7 @@ def _compile_step(
|
||||
try:
|
||||
result = aihub_jobs.submit_compile_job(
|
||||
model=model,
|
||||
device_name=cfg.aihub.device,
|
||||
device=cfg.aihub.device,
|
||||
input_specs=specs,
|
||||
target_runtime=cfg.aihub.target_runtime,
|
||||
options=cfg.aihub.compile_options,
|
||||
@@ -214,6 +244,7 @@ def _validate_step(
|
||||
model_id: str | None,
|
||||
input_name: str | None,
|
||||
) -> str:
|
||||
_validate_device(cfg)
|
||||
specs = _input_specs(cfg)
|
||||
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||
try:
|
||||
@@ -247,6 +278,7 @@ def _validate_step(
|
||||
|
||||
|
||||
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
|
||||
_validate_device(cfg)
|
||||
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||
try:
|
||||
result = aihub_jobs.submit_profile_job(
|
||||
|
||||
Reference in New Issue
Block a user