1 Commits

Author SHA1 Message Date
a8c736e28e WIP: add ai-hub metrics to MLFlow 2026-06-05 14:46:04 -04:00
8 changed files with 565 additions and 200 deletions

View File

@@ -199,6 +199,8 @@ When a step runs in the current command, `upload` passes its returned model ID d
`ai-hub compile` resolves model sources in this order: `--model-id`, explicit source options (`--onnx-path`, `--model-s3-uri`, `--from-job`), last quantized model from state, then the last training job from local state. `ai-hub download` is separate because downloading the optimized artifact is outside the four-step Workbench upload loop.
When MLflow is enabled, AI Hub job-producing commands (`quantize`, `compile`, `validate`, `profile`, and `upload`) log AI Hub metadata to MLflow. Each command execution receives a `qc_cli.aihub_submission_id`; all steps inside one `ai-hub upload` share that submission ID. Runs are nested under the MLflow run for the resolved source model when the CLI can prove that source from local state, such as `--from-job` or a model produced by a prior tracked AI Hub step. Otherwise, AI Hub runs are standalone. `validate` also logs output summaries, and `profile` logs profile metrics plus the raw profile JSON. `ai-hub download` does not create an MLflow run because it does not submit or measure an AI Hub job.
AI Hub authentication currently uses the local `qai-hub` SDK configuration. A planned follow-up is to support AWS Systems Manager Parameter Store `SecureString` for team-managed tokens, where `config.yaml` stores only a parameter name such as `/qc-cli/aihub/token`, AWS KMS encrypts the token at rest, and the CLI retrieves it at runtime with `ssm:GetParameter` plus `kms:Decrypt` permissions.
## Model lifecycle

View File

@@ -1,6 +1,3 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, cast
import boto3
@@ -37,38 +34,3 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) -
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
return str(response["AuthorizedUrl"])
@contextmanager
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
    """Temporarily export the AWS credentials of *profile* into ``os.environ``.

    The credentials are resolved once through a ``boto3.Session`` and frozen,
    then written to the standard ``AWS_*`` variables for the duration of the
    ``with`` body. Every variable touched is put back to its previous value
    (or removed if it was unset) on exit, including on error.

    Raises:
        RuntimeError: if no credentials can be resolved for the profile, or
            the resolved credentials lack an access key or secret key.
    """
    session = boto3.Session(profile_name=profile, region_name=region)
    resolved = session.get_credentials()
    if resolved is None:
        raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")

    frozen = resolved.get_frozen_credentials()
    if not (frozen.access_key and frozen.secret_key):
        raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")

    overrides = {
        "AWS_PROFILE": profile,
        "AWS_DEFAULT_REGION": region,
        "AWS_REGION": region,
        "AWS_ACCESS_KEY_ID": frozen.access_key,
        "AWS_SECRET_ACCESS_KEY": frozen.secret_key,
    }
    has_session_token = bool(frozen.token)
    if has_session_token:
        overrides["AWS_SESSION_TOKEN"] = frozen.token

    # The session token is always snapshotted: it is either overwritten or
    # removed below, so it must be restorable in both cases.
    tracked_names = {*overrides, "AWS_SESSION_TOKEN"}
    snapshot = {name: os.environ.get(name) for name in tracked_names}
    try:
        os.environ.update(overrides)
        if not has_session_token:
            # Drop any stale token so it is not paired with the new key pair.
            os.environ.pop("AWS_SESSION_TOKEN", None)
        yield
    finally:
        for name, earlier in snapshot.items():
            if earlier is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = earlier

View File

@@ -1 +0,0 @@
"""Cloud provider adapters."""

View File

@@ -1,77 +0,0 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Any, Protocol
from src.aws import mlflow as aws_mlflow
from src.config import Config
class MlflowTrackingBackend(Protocol):
    """Structural interface for a provider-specific MLflow tracking backend.

    An implementation supplies the tracking URI, an authentication context,
    and the provider-specific params/tags attached to training runs and
    registered model versions.
    """

    # Short provider identifier.
    @property
    def provider_name(self) -> str: ...

    # Credential profile name used by the provider.
    @property
    def profile(self) -> str: ...

    # Provider region.
    @property
    def region(self) -> str: ...

    # Resolve the MLflow tracking URI for the named tracking server.
    def get_tracking_uri(self, tracking_server_name: str) -> str: ...

    # Context manager under which MLflow calls are authenticated.
    def auth_env(self) -> AbstractContextManager[None]: ...

    # Params describing a submitted training job.
    def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...

    # Tags attached to the training run.
    def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...

    # Params describing the status of a training job.
    def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...

    # Tags attached to a registered model version.
    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
    """AWS implementation of the MLflow tracking backend.

    Delegates URI resolution and authentication to ``src.aws.mlflow`` and
    maps training-job attributes onto ``sagemaker.*`` params and tags.
    """

    # AWS profile name passed to the ``src.aws.mlflow`` helpers.
    profile: str
    # AWS region passed to the ``src.aws.mlflow`` helpers.
    region: str
    # Provider identifier logged as the ``provider.name`` param.
    provider_name: str = "aws"

    def get_tracking_uri(self, tracking_server_name: str) -> str:
        """Return the tracking server ARN, which is used as the tracking URI."""
        return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)

    def auth_env(self) -> AbstractContextManager[None]:
        """Return the context manager from ``aws_mlflow.tracking_auth_env`` for this profile and region."""
        return aws_mlflow.tracking_auth_env(self.profile, self.region)

    def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
        """Build the params for a training run.

        ``region`` and ``profile`` come from the arguments, not from this
        backend's own fields; the rest is read from ``training_job``.
        """
        return {
            "provider.name": self.provider_name,
            "provider.region": region,
            "provider.profile": profile,
            "sagemaker.role_arn": role_arn,
            "sagemaker.job_name": training_job.job_name,
            "sagemaker.training_image": training_job.image_uri,
            "sagemaker.instance_type": training_job.instance_type,
            "sagemaker.instance_count": training_job.instance_count,
            "sagemaker.s3_train_uri": training_job.s3_train_uri,
            "sagemaker.s3_output_path": training_job.s3_output_path,
            "sagemaker.entry_point": training_job.entry_point,
            "sagemaker.source_dir": training_job.source_dir,
        }

    def training_run_tags(self, training_job: Any) -> dict[str, Any]:
        """Return the tag carrying the training job name."""
        return {"sagemaker.job_name": training_job.job_name}

    def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
        """Build the params describing a training job status object."""
        return {
            "sagemaker.training_status": training_job_status.status,
            "sagemaker.created_at": training_job_status.created,
            "sagemaker.modified_at": training_job_status.modified,
            "sagemaker.model_artifacts": training_job_status.model_artifacts,
            "sagemaker.failure_reason": training_job_status.failure_reason,
        }

    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
        """Return the job-name tag for a model version (read from ``.name``, not ``.job_name``)."""
        return {"sagemaker.job_name": training_job_status.name}
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
    """Build the tracking backend from the ``aws`` section of the config.

    Always returns an ``AwsMlflowTrackingBackend``; no other provider is
    selected here.
    """
    aws_settings = cfg.aws
    return AwsMlflowTrackingBackend(
        profile=aws_settings.profile,
        region=aws_settings.region,
    )

View File

@@ -1,8 +1,11 @@
import json
from collections.abc import Mapping, Sequence
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import StrEnum
from pathlib import Path
from typing import Any
from uuid import uuid4
import qai_hub.hub as hub
import typer
@@ -13,6 +16,7 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config
from src.qualcomm import aihub_jobs
from src.qualcomm.artifacts import resolve_onnx
from src.tracking.mlflow import AIHubSourceProvenance, AIHubStepRecord, MlflowTracker, Tracker
app = typer.Typer(help="Quantize, compile, validate, profile, and download models with Qualcomm AI Hub")
@@ -30,6 +34,16 @@ class UploadStep(StrEnum):
profile = "profile"
@dataclass(frozen=True)
class AIHubStepResult:
    """Immutable outcome of one AI Hub step, as returned by the ``_*_step`` helpers."""

    # Job object returned by the ``aihub_jobs.submit_*`` call.
    job: Any
    # Job identifier, stored as a string.
    job_id: str
    # Model ID produced by the step (quantize/compile) or consumed by it (validate/profile).
    model_id: str | None = None
    # Local directory holding files written by the step, when there is one.
    output_dir: Path | None = None
    # Inference outputs keyed by name; set by the validate step only.
    outputs: Mapping[str, Any] | None = None
    # Downloaded profile data; set by the profile step only.
    profile: Mapping[str, Any] | None = None
def _input_specs(cfg: Config) -> dict[str, tuple[tuple[int, ...], str]]:
specs = {name: (tuple(shape), dtype) for name, (shape, dtype) in cfg.aihub.input_specs.items()}
if not specs:
@@ -112,6 +126,116 @@ def _device_selector(device: Device) -> str:
return ", ".join(parts) if parts else "empty selector"
def _submission_id() -> str:
return f"{datetime.now().strftime('%Y%m%d-%H%M%S')}-{uuid4().hex[:8]}"
def _tracker(cfg: Config) -> Tracker:
    """Build the tracker for *cfg*, exiting the CLI if MLflow setup fails.

    Any exception from ``MlflowTracker.from_config`` is printed in red and
    converted to ``typer.Exit(1)``.
    """
    try:
        tracker = MlflowTracker.from_config(cfg)
    except Exception as e:
        CONSOLE.print(f"[red]MLflow setup failed: {e}[/red]")
        raise typer.Exit(1)
    return tracker
def _training_parent_run_id(config_path: str, training_job: str | None) -> str | None:
if not training_job:
return None
run_id = state_ops.store(config_path).get_training_job(training_job).get("mlflow_run_id")
return str(run_id) if run_id else None
def _source_to_state(source: AIHubSourceProvenance) -> dict[str, Any]:
return {key: value for key, value in asdict(source).items() if value is not None}
def _source_from_state(value: Mapping[str, Any]) -> AIHubSourceProvenance:
    """Rebuild a provenance object from its stored mapping form.

    A missing ``kind`` defaults to ``"aihub_model"``. Every other field
    becomes a string when its stored value is truthy, and ``None`` otherwise.
    """

    def _optional_text(key: str) -> str | None:
        # Falsy stored values (missing, None, empty string) all become None.
        return str(value[key]) if value.get(key) else None

    return AIHubSourceProvenance(
        kind=str(value.get("kind", "aihub_model")),
        parent_run_id=_optional_text("parent_run_id"),
        uri=_optional_text("uri"),
        path=_optional_text("path"),
        aihub_model_id=_optional_text("aihub_model_id"),
        training_job=_optional_text("training_job"),
    )
def _source_for_aihub_model(config_path: str, model_id: str) -> AIHubSourceProvenance:
    """Return the provenance for an AI Hub model ID.

    Uses the provenance recorded in local state when there is any; otherwise
    falls back to a bare ``aihub_model`` record carrying only the model ID.
    """
    recorded = state_ops.store(config_path).get_aihub_model_provenance(model_id)
    if not recorded:
        return AIHubSourceProvenance(kind="aihub_model", aihub_model_id=model_id)
    return _source_from_state(recorded)
def _source_for_resolved_onnx(
    config_path: str,
    *,
    resolved_path: Path,
    model_artifact: str | None,
    from_job: str | None,
    model_s3_uri: str | None,
    onnx_path: str | None,
    implicit_training_job: str | None,
    implicit_model_artifact: str | None,
) -> AIHubSourceProvenance:
    """Derive provenance for an ONNX model that was just resolved.

    An existing local ``onnx_path`` given without ``from_job`` or
    ``model_s3_uri`` is plain ``local_onnx`` with no parent run. Otherwise
    the training job is ``from_job`` when set, or the implicit (last) training
    job when either the resolved artifact equals the implicit artifact or no
    explicit source was given at all. The parent MLflow run ID is looked up
    from that job.
    """
    is_plain_local_file = onnx_path and Path(onnx_path).exists() and not (from_job or model_s3_uri)
    if is_plain_local_file:
        return AIHubSourceProvenance(kind="local_onnx", path=str(resolved_path))

    training_job = from_job
    if not training_job:
        artifact_is_last_artifact = bool(
            model_artifact and implicit_model_artifact and model_artifact == implicit_model_artifact
        )
        nothing_explicit = not model_s3_uri and not onnx_path
        if artifact_is_last_artifact or nothing_explicit:
            training_job = implicit_training_job

    if model_artifact:
        source_kind, local_path = "sagemaker_model_artifact", None
    else:
        source_kind, local_path = "local_onnx", str(resolved_path)

    return AIHubSourceProvenance(
        kind=source_kind,
        parent_run_id=_training_parent_run_id(config_path, training_job),
        uri=model_artifact,
        path=local_path,
        training_job=training_job,
    )
def _model_id_or_state_with_source(
    config_path: str,
    model_id: str | None,
    *,
    quantized: bool = False,
) -> tuple[str, AIHubSourceProvenance]:
    """Resolve a model ID via ``_model_id_or_state`` and pair it with its provenance."""
    effective_id = _model_id_or_state(config_path, model_id, quantized=quantized)
    provenance = _source_for_aihub_model(config_path, effective_id)
    return effective_id, provenance
def _record_step(
    cfg: Config,
    tracker: Tracker,
    *,
    result: AIHubStepResult,
    source: AIHubSourceProvenance,
    step: str,
    submission_id: str,
    command: str,
    options: str | None = None,
) -> None:
    """Bundle a step result into an ``AIHubStepRecord`` and pass it to the tracker.

    Target runtime and device selector come from ``cfg.aihub``; everything
    else is copied from *result* and the keyword arguments. The tracker's
    return value is discarded.
    """
    aihub_cfg = cfg.aihub
    step_record = AIHubStepRecord(
        step=step,
        submission_id=submission_id,
        command=command,
        source=source,
        job=result.job,
        job_id=result.job_id,
        model_id=result.model_id,
        target_runtime=aihub_cfg.target_runtime,
        device=_device_selector(aihub_cfg.device),
        options=options,
        output_dir=result.output_dir,
        outputs=result.outputs,
        profile=result.profile,
    )
    tracker.record_aihub_step(step_record)
def _validate_device(cfg: Config) -> None:
device = cfg.aihub.device
try:
@@ -135,23 +259,38 @@ def _quantize_step(
from_job: str | None,
model_s3_uri: str | None,
onnx_path: str | None,
) -> str:
tracker: Tracker,
submission_id: str,
) -> AIHubStepResult:
st = state_ops.store(config_path)
specs = _input_specs(cfg)
implicit_training_job = st.get_last_training_job()
implicit_model_artifact = st.get_last_model_artifact()
try:
resolved = resolve_onnx(
cfg=cfg,
output_dir=cfg.aihub.output_dir,
from_job=from_job,
model_s3_uri=model_s3_uri or st.get_last_model_artifact(),
model_s3_uri=model_s3_uri or implicit_model_artifact,
onnx_path=onnx_path,
last_training_job=st.get_last_training_job(),
last_training_job=implicit_training_job,
)
calibration_data = _load_calibration(calibration_path, specs)
except (FileNotFoundError, ValueError) as e:
CONSOLE.print(f"[red]{e}[/red]")
raise typer.Exit(1)
source = _source_for_resolved_onnx(
config_path,
resolved_path=resolved.onnx_path,
model_artifact=resolved.model_artifact,
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
implicit_training_job=implicit_training_job,
implicit_model_artifact=implicit_model_artifact,
)
try:
result = aihub_jobs.submit_quantize_job(
resolved.onnx_path,
@@ -169,9 +308,25 @@ def _quantize_step(
last_quantize_job_id=result["job_id"],
last_quantized_model_id=result["model_id"],
)
st.update_aihub_model_provenance(str(result["model_id"]), _source_to_state(source))
step_result = AIHubStepResult(
job=result["job"],
job_id=str(result["job_id"]),
model_id=str(result["model_id"]),
)
_record_step(
cfg,
tracker,
result=step_result,
source=source,
step="quantize",
submission_id=submission_id,
command="ai-hub quantize",
options=cfg.aihub.quantize_options,
)
CONSOLE.print(f"[green]✓[/green] Quantize job: [bold]{result['job_id']}[/bold]")
CONSOLE.print(f"[green]✓[/green] Quantized model: [bold]{result['model_id']}[/bold]")
return str(result["model_id"])
return step_result
def _compile_step(
@@ -183,19 +338,25 @@ def _compile_step(
onnx_path: str | None,
*,
prefer_quantized: bool,
) -> str:
tracker: Tracker,
submission_id: str,
) -> AIHubStepResult:
st = state_ops.store(config_path)
_validate_device(cfg)
specs = _input_specs(cfg)
model: Any
model_artifact: str | None = None
source: AIHubSourceProvenance
has_explicit_source = bool(from_job or model_s3_uri or onnx_path)
if model_id:
model = model_id
source = _source_for_aihub_model(config_path, model_id)
elif prefer_quantized and not has_explicit_source and st.get_last_quantized_model_id():
model = st.get_last_quantized_model_id()
source = _source_for_aihub_model(config_path, str(model))
else:
implicit_training_job = st.get_last_training_job()
try:
resolved = resolve_onnx(
cfg=cfg,
@@ -203,13 +364,23 @@ def _compile_step(
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
last_training_job=st.get_last_training_job(),
last_training_job=implicit_training_job,
)
except (FileNotFoundError, ValueError) as e:
CONSOLE.print(f"[red]{e}[/red]")
raise typer.Exit(1)
model = resolved.onnx_path
model_artifact = resolved.model_artifact
source = _source_for_resolved_onnx(
config_path,
resolved_path=resolved.onnx_path,
model_artifact=resolved.model_artifact,
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
implicit_training_job=implicit_training_job,
implicit_model_artifact=st.get_last_model_artifact(),
)
try:
result = aihub_jobs.submit_compile_job(
@@ -232,9 +403,25 @@ def _compile_step(
if model_artifact:
updates["last_model_artifact"] = model_artifact
st.update(**updates)
st.update_aihub_model_provenance(str(result["model_id"]), _source_to_state(source))
step_result = AIHubStepResult(
job=result["job"],
job_id=str(result["job_id"]),
model_id=str(result["model_id"]),
)
_record_step(
cfg,
tracker,
result=step_result,
source=source,
step="compile",
submission_id=submission_id,
command="ai-hub compile",
options=cfg.aihub.compile_options,
)
CONSOLE.print(f"[green]✓[/green] Compile job: [bold]{result['job_id']}[/bold]")
CONSOLE.print(f"[green]✓[/green] Compiled model: [bold]{result['model_id']}[/bold]")
return str(result["model_id"])
return step_result
def _validate_step(
@@ -243,10 +430,12 @@ def _validate_step(
input_file: Path,
model_id: str | None,
input_name: str | None,
) -> str:
tracker: Tracker,
submission_id: str,
) -> AIHubStepResult:
_validate_device(cfg)
specs = _input_specs(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id)
resolved_model_id, source = _model_id_or_state_with_source(config_path, model_id)
try:
inputs = _load_inputs(input_file, specs, input_name)
except (FileNotFoundError, ValueError) as e:
@@ -268,18 +457,40 @@ def _validate_step(
raise typer.Exit(1)
state_ops.store(config_path).update(last_inference_job_id=result["job_id"])
CONSOLE.print(f"[green]✓[/green] Inference job: [bold]{result['job_id']}[/bold]")
outputs = result.get("outputs")
step_result = AIHubStepResult(
job=result["job"],
job_id=str(result["job_id"]),
model_id=resolved_model_id,
output_dir=out_dir,
outputs=outputs if isinstance(outputs, Mapping) else None,
)
_record_step(
cfg,
tracker,
result=step_result,
source=source,
step="validate",
submission_id=submission_id,
command="ai-hub validate",
)
CONSOLE.print(f"[green]✓[/green] Inference job: [bold]{result['job_id']}[/bold]")
if isinstance(outputs, dict):
for name, value in outputs.items():
CONSOLE.print(f" {name}: shape={getattr(value, 'shape', '?')}")
CONSOLE.print(f"Outputs: [cyan]{out_dir}[/cyan]")
return str(result["job_id"])
return step_result
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
def _profile_step(
cfg: Config,
config_path: str,
model_id: str | None,
tracker: Tracker,
submission_id: str,
) -> AIHubStepResult:
_validate_device(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id)
resolved_model_id, source = _model_id_or_state_with_source(config_path, model_id)
try:
result = aihub_jobs.submit_profile_job(
resolved_model_id,
@@ -290,9 +501,41 @@ def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
except Exception as e:
CONSOLE.print(f"[red]AI Hub profile failed: {e}[/red]")
raise typer.Exit(1)
run = datetime.now().strftime("%Y%m%d-%H%M%S")
out_dir = Path(cfg.aihub.output_dir) / run / "profile"
try:
out_dir.mkdir(parents=True, exist_ok=True)
profile_data = result["job"].download_profile()
if isinstance(profile_data, Mapping):
(out_dir / "profile.json").write_text(json.dumps(profile_data, indent=2), encoding="utf-8")
else:
profile_data = {}
except Exception as e:
CONSOLE.print(f"[red]AI Hub profile download failed: {e}[/red]")
raise typer.Exit(1)
state_ops.store(config_path).update(last_profile_job_id=result["job_id"])
step_result = AIHubStepResult(
job=result["job"],
job_id=str(result["job_id"]),
model_id=resolved_model_id,
output_dir=out_dir,
profile=profile_data,
)
_record_step(
cfg,
tracker,
result=step_result,
source=source,
step="profile",
submission_id=submission_id,
command="ai-hub profile",
options=cfg.aihub.profile_options,
)
CONSOLE.print(f"[green]✓[/green] Profile job: [bold]{result['job_id']}[/bold]")
return str(result["job_id"])
CONSOLE.print(f"Profile: [cyan]{out_dir}[/cyan]")
return step_result
@app.command()
@@ -307,7 +550,16 @@ def quantize(
) -> None:
"""Quantize an ONNX model to INT8."""
cfg = load_cfg(config)
_quantize_step(cfg, config, calibration_path, from_job, model_s3_uri, onnx_path)
_quantize_step(
cfg,
config,
calibration_path,
from_job,
model_s3_uri,
onnx_path,
_tracker(cfg),
_submission_id(),
)
@app.command()
@@ -322,7 +574,17 @@ def compile(
) -> None:
"""Compile a model for the configured Qualcomm AI Hub target."""
cfg = load_cfg(config)
_compile_step(cfg, config, model_id, from_job, model_s3_uri, onnx_path, prefer_quantized=True)
_compile_step(
cfg,
config,
model_id,
from_job,
model_s3_uri,
onnx_path,
prefer_quantized=True,
tracker=_tracker(cfg),
submission_id=_submission_id(),
)
@app.command()
@@ -334,7 +596,7 @@ def validate(
) -> None:
"""Run an AI Hub inference job using sample inputs."""
cfg = load_cfg(config)
_validate_step(cfg, config, input_file, model_id, input_name)
_validate_step(cfg, config, input_file, model_id, input_name, _tracker(cfg), _submission_id())
@app.command()
@@ -344,7 +606,7 @@ def profile(
) -> None:
"""Profile a compiled model on the configured AI Hub device."""
cfg = load_cfg(config)
_profile_step(cfg, config, model_id)
_profile_step(cfg, config, model_id, _tracker(cfg), _submission_id())
@app.command()
@@ -364,13 +626,25 @@ def upload(
cfg = load_cfg(config)
steps = [UploadStep.quantize, UploadStep.compile, UploadStep.validate, UploadStep.profile]
selected = steps[steps.index(from_step) :]
tracker = _tracker(cfg)
submission_id = _submission_id()
quantized_model_id: str | None = None
compiled_model_id: str | None = None
if UploadStep.quantize in selected:
quantized_model_id = _quantize_step(cfg, config, calibration_path, from_job, model_s3_uri, onnx_path)
quantized = _quantize_step(
cfg,
config,
calibration_path,
from_job,
model_s3_uri,
onnx_path,
tracker,
submission_id,
)
quantized_model_id = quantized.model_id
if UploadStep.compile in selected:
compiled_model_id = _compile_step(
compiled = _compile_step(
cfg,
config,
model_id=quantized_model_id,
@@ -378,11 +652,14 @@ def upload(
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
prefer_quantized=True,
tracker=tracker,
submission_id=submission_id,
)
compiled_model_id = compiled.model_id
if UploadStep.validate in selected:
_validate_step(cfg, config, input_file, compiled_model_id, input_name)
_validate_step(cfg, config, input_file, compiled_model_id, input_name, tracker, submission_id)
if UploadStep.profile in selected:
_profile_step(cfg, config, compiled_model_id)
_profile_step(cfg, config, compiled_model_id, tracker, submission_id)
@app.command()

View File

@@ -67,6 +67,18 @@ class CliStateStore:
def set_latest_experiment_model_version(self, version: str) -> None:
self.update(latest_experiment_model_version=version)
def get_aihub_model_provenance(self, model_id: str) -> dict[str, Any]:
    """Return a copy of the provenance stored for *model_id*.

    Returns an empty dict when nothing is stored or the stored entry is not
    a dict.
    """
    entry = self._aihub_model_provenance(self.read()).get(model_id, {})
    if not isinstance(entry, dict):
        return {}
    return dict(entry)
def update_aihub_model_provenance(self, model_id: str, provenance: dict[str, Any]) -> None:
    """Store *provenance* under *model_id* and write the whole state back.

    Provenance entries for other model IDs are kept; an existing entry for
    the same ID is replaced.
    """
    current = self.read()
    current["aihub_model_provenance"] = {
        **self._aihub_model_provenance(current),
        model_id: provenance,
    }
    self._write(current)
def _write(self, state: dict[str, Any]) -> None:
with open(self.path, "w") as f:
json.dump(state, f, indent=2)
@@ -75,6 +87,10 @@ class CliStateStore:
value = state.get("training_jobs", {})
return dict(value) if isinstance(value, dict) else {}
def _aihub_model_provenance(self, state: dict[str, Any]) -> dict[str, Any]:
value = state.get("aihub_model_provenance", {})
return dict(value) if isinstance(value, dict) else {}
def store(config_path: str) -> CliStateStore:
config_dir = str(Path(config_path).parent)

View File

@@ -1,3 +1,3 @@
from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker
from src.tracking.mlflow import AIHubSourceProvenance, AIHubStepRecord, MlflowTracker, NoopTracker, Tracker
__all__ = ["MlflowTracker", "NoopTracker", "Tracker"]
__all__ = ["AIHubSourceProvenance", "AIHubStepRecord", "MlflowTracker", "NoopTracker", "Tracker"]

View File

@@ -1,11 +1,14 @@
import os
import re
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Protocol
import mlflow
from mlflow.tracking import MlflowClient
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
from src.aws import mlflow as aws_mlflow
from src.config import Config, MlflowMode
@@ -14,6 +17,35 @@ class Tracker(Protocol):
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: ...
def record_aihub_step(self, record: "AIHubStepRecord") -> str | None: ...
@dataclass(frozen=True)
class AIHubSourceProvenance:
    """Where the model fed into an AI Hub step came from."""

    # Source category, e.g. "aihub_model", "local_onnx", or "sagemaker_model_artifact".
    kind: str
    # MLflow run ID that the AI Hub run is nested under, when one is known.
    parent_run_id: str | None = None
    # URI of the source model artifact, when the source is a stored artifact.
    uri: str | None = None
    # Local filesystem path of the source model, when the source is a local file.
    path: str | None = None
    # AI Hub model ID, when the source is an existing AI Hub model.
    aihub_model_id: str | None = None
    # Name of the training job the source model is attributed to.
    training_job: str | None = None
@dataclass(frozen=True)
class AIHubStepRecord:
    """Everything a tracker needs to log one AI Hub step."""

    # Step name; also used in the MLflow run name and the stage tag.
    step: str
    # ID shared by all steps recorded under one submission.
    submission_id: str
    # CLI command string, logged as a tag.
    command: str
    # Provenance of the model the step operated on.
    source: AIHubSourceProvenance
    # Job object; its status and descriptive attributes are read when logging.
    job: Any | None = None
    # Job ID; when absent, the tracker falls back to the job object's ``job_id``.
    job_id: str | None = None
    # Model ID associated with the step.
    model_id: str | None = None
    # Target runtime the step ran for.
    target_runtime: str | None = None
    # Device selector description.
    device: str | None = None
    # Options string; when empty, the tracker falls back to the job object's ``options``.
    options: str | None = None
    # Directory whose contents are uploaded as artifacts if it exists.
    output_dir: str | Path | None = None
    # Inference outputs keyed by name, summarised into params and metrics.
    outputs: Mapping[str, Any] | None = None
    # Profile data, logged as JSON and flattened into metrics.
    profile: Mapping[str, Any] | None = None
@dataclass(frozen=True)
class NoopTracker:
@@ -23,6 +55,9 @@ class NoopTracker:
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
return None
def record_aihub_step(self, record: AIHubStepRecord) -> str | None:
    """Ignore the record and return ``None``; the no-op tracker logs nothing."""
    return None
@dataclass(frozen=True)
class MlflowTracker:
@@ -30,7 +65,6 @@ class MlflowTracker:
experiment_name: str
registered_model_name: str
register_trained_models: bool
tracking_backend: MlflowTrackingBackend
@classmethod
def from_config(cls, cfg: Config) -> Tracker:
@@ -43,10 +77,11 @@ class MlflowTracker:
if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_backend = mlflow_tracking_backend_from_config(cfg)
tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
with tracking_backend.auth_env():
tracking_uri = aws_mlflow.get_tracking_server_arn(
cfg.aws.region,
cfg.aws.profile,
tracking_server_name,
)
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name)
@@ -55,30 +90,34 @@ class MlflowTracker:
experiment_name=cfg.mlflow.experiment_name,
registered_model_name=cfg.mlflow.registered_model_name,
register_trained_models=cfg.mlflow.register_trained_models,
tracking_backend=tracking_backend,
)
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
with self.tracking_backend.auth_env():
run = mlflow.start_run(run_name=training_job.job_name)
run_id = str(run.info.run_id)
self._log_params(
self.tracking_backend.training_run_params(
training_job,
region=region,
profile=profile,
role_arn=role_arn,
)
)
params = {
"aws.region": region,
"aws.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags(
{
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name,
"qc_cli.source": "sagemaker",
"qc_cli.command": "train start",
**self.tracking_backend.training_run_tags(training_job),
"sagemaker.job_name": training_job.job_name,
}
)
mlflow.end_run()
@@ -88,9 +127,16 @@ class MlflowTracker:
if not run_id:
return None
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id):
self._log_params(self.tracking_backend.training_status_params(training_job_status))
self._log_params(
{
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
)
self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", "train status")
@@ -110,8 +156,8 @@ class MlflowTracker:
tags={
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name,
**self.tracking_backend.model_version_tags(training_job_status),
"qc_cli.source": "sagemaker",
"sagemaker.job_name": training_job_status.name,
},
)
version_number = str(version.version)
@@ -120,6 +166,21 @@ class MlflowTracker:
mlflow.set_tag("qc_cli.registered_model_version", version_number)
return version_number
def record_aihub_step(self, record: AIHubStepRecord) -> str | None:
    """Log one AI Hub step as an MLflow run and return the new run's ID.

    When the record's source carries a ``parent_run_id``, that run is
    re-opened and the step is logged as a nested child run named
    ``ai-hub <step>``. Otherwise the step is logged as a standalone run.

    Both runs are opened through MLflow's context managers so they are always
    closed, and so a failure while logging ends the child run with a failed
    status instead of a finished one.

    Args:
        record: The step to log.

    Returns:
        The run ID of the run that holds the step's data.

    Raises:
        Whatever MLflow raises while opening runs or logging.
    """
    run_name = f"ai-hub {record.step}"
    parent_run_id = record.source.parent_run_id

    if parent_run_id:
        # NOTE(review): re-opening the parent via start_run(run_id=...) also
        # ends it again on exit, including with a failed status on error,
        # which is the same as before. Confirm that is acceptable for
        # training runs.
        with mlflow.start_run(run_id=parent_run_id):
            with mlflow.start_run(run_name=run_name, nested=True) as child:
                self._log_aihub_record(record)
                return str(child.info.run_id)

    with mlflow.start_run(run_name=run_name) as run:
        self._log_aihub_record(record)
        return str(run.info.run_id)
def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None}
if cleaned:
@@ -140,3 +201,128 @@ class MlflowTracker:
client.get_registered_model(name)
except Exception:
client.create_registered_model(name)
def _log_aihub_record(self, record: AIHubStepRecord) -> None:
    """Log params, tags, output statistics, profile data, and artifacts for one step.

    Uses the fluent ``mlflow`` API, so the caller must already have a run
    open. Job attributes are read through ``_job_attr``; the explicit
    ``job_id`` and ``options`` on the record take precedence over the job
    object's own values.
    """
    job = record.job
    source = record.source
    status = self._job_status(job)

    job_params = {
        "aihub.step": record.step,
        "aihub.submission_id": record.submission_id,
        "aihub.job_id": record.job_id or self._job_attr(job, "job_id"),
        "aihub.job_name": self._job_attr(job, "name"),
        "aihub.job_type": self._job_attr(job, "job_type"),
        "aihub.job_url": self._job_attr(job, "url"),
        "aihub.model_id": record.model_id,
        "aihub.target_runtime": record.target_runtime,
        "aihub.device": record.device,
        "aihub.options": record.options or self._job_attr(job, "options"),
        "aihub.status": status.get("code"),
        "aihub.failure_reason": status.get("message"),
        "aihub.output_dir": record.output_dir,
    }
    source_params = {
        "qc_cli.source_model.kind": source.kind,
        "qc_cli.source_model.uri": source.uri,
        "qc_cli.source_model.path": source.path,
        "qc_cli.source_model.aihub_model_id": source.aihub_model_id,
        "qc_cli.source_training_job": source.training_job,
        "qc_cli.parent_mlflow_run_id": source.parent_run_id,
    }
    self._log_params({**job_params, **source_params})

    step_tags = {
        "qc_cli.source": "ai_hub",
        "qc_cli.stage": record.step,
        "qc_cli.command": record.command,
        "qc_cli.aihub_submission_id": record.submission_id,
    }
    mlflow.set_tags(step_tags)

    self._log_output_stats(record.outputs)
    self._log_profile(record.profile)

    if not record.output_dir:
        return
    artifact_root = Path(record.output_dir)
    # Only upload when the directory is really there; a step may report a
    # location without having written anything to it.
    if artifact_root.exists() and artifact_root.is_dir():
        mlflow.log_artifacts(str(artifact_root), artifact_path=f"aihub/{record.step}")
def _log_output_stats(self, outputs: Mapping[str, Any] | None) -> None:
    """Log shape/dtype params and summary metrics for each inference output.

    Each value is converted with ``np.asarray``. Shape, dtype, and element
    count are always logged. NaN/Inf counts are added for non-empty numeric
    arrays, and min/max/mean/std/L1/L2 over the finite elements when there
    are any. Does nothing for empty or missing outputs.
    """
    if not outputs:
        return

    import numpy as np

    shape_params: dict[str, Any] = {}
    stats: dict[str, float] = {}
    for output_name, raw in outputs.items():
        key = f"aihub.inference.output.{self._metric_name(output_name)}"
        tensor = np.asarray(raw)
        shape_params[f"{key}.shape"] = list(tensor.shape)
        shape_params[f"{key}.dtype"] = str(tensor.dtype)
        stats[f"{key}.count"] = float(tensor.size)

        has_numeric_data = tensor.size != 0 and np.issubdtype(tensor.dtype, np.number)
        if not has_numeric_data:
            continue

        as_float = tensor.astype(float, copy=False)
        stats[f"{key}.nan_count"] = float(np.isnan(as_float).sum())
        stats[f"{key}.inf_count"] = float(np.isinf(as_float).sum())

        # Boolean-mask indexing yields a flat array of the finite elements.
        finite_values = as_float[np.isfinite(as_float)]
        if finite_values.size:
            stats[f"{key}.min"] = float(finite_values.min())
            stats[f"{key}.max"] = float(finite_values.max())
            stats[f"{key}.mean"] = float(finite_values.mean())
            stats[f"{key}.std"] = float(finite_values.std())
            stats[f"{key}.l1_norm"] = float(np.linalg.norm(finite_values, ord=1))
            stats[f"{key}.l2_norm"] = float(np.linalg.norm(finite_values, ord=2))

    self._log_params(shape_params)
    if stats:
        mlflow.log_metrics(stats)
def _log_profile(self, profile: Mapping[str, Any] | None) -> None:
    """Log profile data as ``aihub/profile.json`` plus one metric per numeric leaf.

    Metric names are ``aihub.profile.<sanitised dotted path>``. Does nothing
    for empty or missing profile data.
    """
    if not profile:
        return

    mlflow.log_dict(dict(profile), "aihub/profile.json")

    profile_metrics: dict[str, float] = {}
    for dotted_path, number in self._flatten_numeric(profile).items():
        profile_metrics[f"aihub.profile.{self._metric_name(dotted_path)}"] = float(number)

    if profile_metrics:
        mlflow.log_metrics(profile_metrics)
def _flatten_numeric(self, value: Any, prefix: str = "") -> dict[str, float]:
if isinstance(value, Mapping):
flattened: dict[str, float] = {}
for key, item in value.items():
child_prefix = f"{prefix}.{key}" if prefix else str(key)
flattened.update(self._flatten_numeric(item, child_prefix))
return flattened
if isinstance(value, list | tuple):
flattened = {}
for index, item in enumerate(value):
child_prefix = f"{prefix}.{index}" if prefix else str(index)
flattened.update(self._flatten_numeric(item, child_prefix))
return flattened
if isinstance(value, bool):
return {}
if isinstance(value, int | float):
return {prefix: float(value)}
return {}
def _job_status(self, job: Any | None) -> dict[str, Any]:
if job is None or not hasattr(job, "get_status"):
return {}
status = job.get_status()
return {
"code": getattr(status, "code", None),
"message": getattr(status, "message", None),
}
def _job_attr(self, job: Any | None, name: str) -> Any:
if job is None:
return None
try:
return getattr(job, name)
except Exception:
return None
def _metric_name(self, value: str) -> str:
return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(value)).strip("._") or "value"