This commit is contained in:
2026-06-12 11:42:26 -04:00
parent 522ddc74e2
commit 2d4d377051
8 changed files with 390 additions and 38 deletions

View File

@@ -1,3 +1,4 @@
import time
from datetime import datetime
from pathlib import Path
@@ -21,6 +22,8 @@ _STATUS_COLOR = {
"Stopping": "yellow",
"Stopped": "dim",
}
# Training-job statuses treated as final: `wait` stops polling and MLflow
# finalization is only attempted once a job reports one of these.
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
# Default for the `--poll-interval` option of the `wait` command, in seconds.
DEFAULT_POLL_INTERVAL_SECONDS = 30
def _tracker(cfg):
@@ -48,6 +51,57 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
    """Print a training job's name, status and optional details to the console.

    The status value is coloured via ``_STATUS_COLOR`` (white when the status
    has no entry). The creation time, model artifacts and failure reason are
    printed only when they are set.
    """
    emit = CONSOLE.print
    state = status.status
    tint = _STATUS_COLOR.get(state, "white")

    emit(f"Job: [cyan]{status.name}[/cyan]")
    emit(f"Status: [{tint}]{state}[/{tint}]")

    # The plain optional fields share the same "<label>: <value>" layout.
    for label, attr in (("Created", "created"), ("Artifacts", "model_artifacts")):
        detail = getattr(status, attr)
        if detail:
            emit(f"{label}: {detail}")

    reason = status.failure_reason
    if reason:
        emit(f"[red]Failure: {reason}[/red]")
def _finalize_terminal_job(
    *,
    config_path: str,
    cfg: Config,
    status: sm_ops.TrainingJobStatus,
    command: str,
) -> None:
    """Finalize the MLflow run of a finished training job, at most once.

    Nothing happens unless ``status`` is one of ``_TERMINAL_STATUSES``, the
    stored job state has an ``mlflow_run_id`` and no ``mlflow_finalized_status``
    has been recorded yet. Otherwise the tracker finalizes the run, the result
    is written back to the state store, and any warnings and the registered
    model version are printed.

    Args:
        config_path: Path of the config file; used to open the state store.
        cfg: Loaded configuration (tracker selection, AWS region and profile).
        status: Latest status of the training job.
        command: Name of the CLI command that triggered finalization.
    """
    if status.status not in _TERMINAL_STATUSES:
        return

    store = state_ops.store(config_path)
    recorded = store.get_training_job(status.name)
    mlflow_run = recorded.get("mlflow_run_id")
    if not mlflow_run:
        return
    if recorded.get("mlflow_finalized_status"):
        # A previous invocation already finalized this run.
        return

    outcome = _tracker(cfg).finalize_training_run(
        run_id=str(mlflow_run),
        training_job_status=status,
        region=cfg.aws.region,
        profile=cfg.aws.profile,
        command=command,
    )
    model_version = outcome.registered_model_version

    # Record the finalization first so later invocations skip this job.
    state_patch = {"mlflow_finalized_status": status.status}
    if model_version:
        state_patch["registered_model_version"] = model_version
    store.update_training_job(status.name, **state_patch)

    for note in outcome.warnings:
        CONSOLE.print(f"[yellow]MLflow metrics warning: {note}[/yellow]")

    if not model_version:
        return
    store.set_latest_experiment_model_version(model_version)
    CONSOLE.print(
        f"MLflow model version: [cyan]{model_version}[/cyan] "
        "([cyan]experiment-latest[/cyan])"
    )
@app.command()
def start(config: str = CONFIG_OPT) -> None:
"""Submit a SageMaker training job."""
@@ -123,37 +177,65 @@ def status(
raise typer.Exit(1)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
color = _STATUS_COLOR.get(status.status, "white")
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
if status.created:
CONSOLE.print(f"Created: {status.created}")
if status.model_artifacts:
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
if status.failure_reason:
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
_print_training_status(status)
_finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status")
job_state = st.get_training_job(job_name)
run_id = job_state.get("mlflow_run_id")
already_registered = job_state.get("registered_model_version")
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
tracker = _tracker(cfg)
version = tracker.finalize_training_run(
run_id=str(run_id),
training_job_status=status,
)
updates = {"mlflow_finalized_status": status.status}
if version:
updates["registered_model_version"] = version
st.update_training_job(job_name, **updates)
if version:
st.set_latest_experiment_model_version(version)
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@app.command()
def wait(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    poll_interval: int = typer.Option(
        DEFAULT_POLL_INTERVAL_SECONDS,
        "--poll-interval",
        min=1,
        help="Seconds between SageMaker status checks",
    ),
    config: str = CONFIG_OPT,
) -> None:
    """Wait for a training job and finalize its MLflow run."""
    # Polls the job status every `poll_interval` seconds, printing a line each
    # time the status changes, until the job reaches a terminal status. Ctrl-C
    # stops the waiting and exits with code 130.
    cfg = load_cfg(config)
    store = state_ops.store(config)

    # Fall back to the last job recorded in state when no name was given.
    job_name = job_name or store.get_last_training_job()
    if not job_name:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)

    last_reported: str | None = None
    try:
        while True:
            snapshot = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
            current = snapshot.status

            if current != last_reported:
                tint = _STATUS_COLOR.get(current, "white")
                CONSOLE.print(f"Job [cyan]{snapshot.name}[/cyan]: [{tint}]{current}[/{tint}]")
                last_reported = current

            if current not in _TERMINAL_STATUSES:
                time.sleep(poll_interval)
                continue

            # Terminal status: report details, finalize MLflow, then stop.
            _print_training_status(snapshot)
            _finalize_terminal_job(
                config_path=config,
                cfg=cfg,
                status=snapshot,
                command="train wait",
            )
            recorded = store.get_training_job(job_name)
            if recorded.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
                CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
            return
    except KeyboardInterrupt:
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        raise typer.Exit(130)
@app.command(name="list")
def list_jobs(
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),

View File

@@ -1,3 +1,3 @@
from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker
from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
__all__ = ["MlflowTracker", "NoopTracker", "Tracker"]
__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]

93
src/tracking/metrics.py Normal file
View File

@@ -0,0 +1,93 @@
import json
import math
import tarfile
from dataclasses import dataclass
from pathlib import PurePosixPath
from typing import Any
# File name of the metrics document looked up inside the model archive.
METRICS_ARTIFACT_NAME = "training_metrics.json"
# Only documents whose "schema_version" equals this value are accepted.
METRICS_SCHEMA_VERSION = 1
# Largest metrics file (10 MiB) that will be read out of a model archive.
MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
@dataclass(frozen=True)
class MetricStep:
    """Metric values recorded for a single training step."""

    # Step number the metrics belong to.
    step: int
    # Metric name -> value recorded at this step.
    metrics: dict[str, float]
@dataclass(frozen=True)
class TrainingMetrics:
    """Validated contents of a training metrics document."""

    # Per-step metric records, in document order.
    steps: list[MetricStep]
    # Metric name -> value taken from the document's "summary" object.
    summary: dict[str, float]
    # The decoded JSON object exactly as it was read.
    raw: dict[str, Any]
def parse_training_metrics(data: bytes) -> TrainingMetrics:
    """Decode and validate the contents of a training metrics document.

    Args:
        data: Raw JSON bytes of the metrics document.

    Returns:
        The validated per-step metrics, the summary metrics and the decoded
        JSON object itself.

    Raises:
        ValueError: If the payload is not valid JSON, is not a JSON object,
            declares an unsupported ``schema_version``, or contains an invalid
            step or metric entry.
    """
    try:
        document = json.loads(data)
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc

    if not isinstance(document, dict):
        raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")

    schema_version = document.get("schema_version")
    if schema_version != METRICS_SCHEMA_VERSION:
        raise ValueError(f"Unsupported training metrics schema version: {schema_version!r}")

    step_entries = document.get("steps")
    if not isinstance(step_entries, list):
        raise ValueError("training metrics 'steps' must be a list")

    parsed_steps: list[MetricStep] = []
    last_step: int | None = None
    for position, entry in enumerate(step_entries):
        if not isinstance(entry, dict):
            raise ValueError(f"training metrics step {position} must be an object")

        number = entry.get("step")
        # bool is a subclass of int, so it has to be excluded explicitly.
        is_valid_step = isinstance(number, int) and not isinstance(number, bool) and number >= 0
        if not is_valid_step:
            raise ValueError(f"training metrics step {position} has an invalid 'step'")
        if last_step is not None and number <= last_step:
            raise ValueError("training metrics steps must be unique and strictly increasing")

        values = _numeric_metrics(entry.get("metrics"), f"training metrics step {number}")
        parsed_steps.append(MetricStep(step=number, metrics=values))
        last_step = number

    # A missing summary is treated as an empty one.
    summary = _numeric_metrics(document.get("summary", {}), "training metrics summary")
    return TrainingMetrics(steps=parsed_steps, summary=summary, raw=document)
def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
    """Read the metrics file out of a model archive.

    Regular files whose base name equals ``METRICS_ARTIFACT_NAME`` are matched
    at any directory depth inside the archive.

    Args:
        archive_path: Path of the tar archive on the local filesystem.

    Returns:
        The raw bytes of the metrics file, or ``None`` when the archive does
        not contain one.

    Raises:
        ValueError: If more than one metrics file is present, the file exceeds
            ``MAX_METRICS_ARTIFACT_BYTES``, or it cannot be opened for reading.
    """
    with tarfile.open(archive_path, mode="r:*") as archive:
        candidates = []
        for entry in archive.getmembers():
            if not entry.isfile():
                continue
            if PurePosixPath(entry.name).name == METRICS_ARTIFACT_NAME:
                candidates.append(entry)

        if not candidates:
            return None
        if len(candidates) > 1:
            raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")

        target = candidates[0]
        # Check the declared size before reading anything into memory.
        if target.size > MAX_METRICS_ARTIFACT_BYTES:
            raise ValueError(
                f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
            )

        stream = archive.extractfile(target)
        if stream is None:
            raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
        return stream.read()
def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
if not isinstance(value, dict):
raise ValueError(f"{context} 'metrics' must be an object")
metrics: dict[str, float] = {}
for raw_name, raw_value in value.items():
if not isinstance(raw_name, str) or not raw_name:
raise ValueError(f"{context} contains an invalid metric name")
if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
raise ValueError(f"{context} metric '{raw_name}' must be numeric")
metric_value = float(raw_value)
if not math.isfinite(metric_value):
raise ValueError(f"{context} metric '{raw_name}' must be finite")
metrics[raw_name] = metric_value
return metrics

View File

@@ -1,4 +1,5 @@
import os
import tempfile
from dataclasses import dataclass
from typing import Any, Protocol
@@ -6,13 +7,29 @@ import mlflow
from mlflow.tracking import MlflowClient
from src.aws import mlflow as aws_mlflow
from src.aws import s3
from src.config import Config, MlflowMode
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
@dataclass(frozen=True)
class FinalizeResult:
registered_model_version: str | None = None
warnings: tuple[str, ...] = ()
class Tracker(Protocol):
    """Structural interface implemented by the experiment trackers."""

    def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
        """Start tracking ``training_job``; return the run id, or ``None`` if no run was created."""
        ...

    # The earlier two-argument ``finalize_training_run`` signature is dropped:
    # it was shadowed by the definition below and could never take effect.
    def finalize_training_run(
        self,
        *,
        run_id: str | None,
        training_job_status: Any,
        region: str,
        profile: str,
        command: str,
    ) -> FinalizeResult:
        """Finalize the tracked run for a training job.

        Args:
            run_id: Identifier of the run to finalize; may be ``None``.
            training_job_status: Latest status object of the training job.
            region: AWS region passed through to the tracker.
            profile: AWS profile passed through to the tracker.
            command: Name of the CLI command that triggered finalization
                (for example ``"train status"``).

        Returns:
            A ``FinalizeResult`` with the registered model version, if any,
            and any warnings.
        """
        ...
@dataclass(frozen=True)
@@ -20,8 +37,16 @@ class NoopTracker:
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
return None
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
return None
def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult:
return FinalizeResult()
@dataclass(frozen=True)
@@ -88,10 +113,19 @@ class MlflowTracker:
mlflow.end_run()
return run_id
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult:
if not run_id:
return None
return FinalizeResult()
warnings: list[str] = []
with mlflow.start_run(run_id=run_id):
self._log_params(
{
@@ -103,14 +137,22 @@ class MlflowTracker:
}
)
self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", "train status")
if training_job_status.status == "Completed" and training_job_status.model_artifacts:
warnings.extend(
self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
)
mlflow.set_tag("qc_cli.command", command)
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
return None
return FinalizeResult(warnings=tuple(warnings))
if not self.register_trained_models:
return None
return FinalizeResult(warnings=tuple(warnings))
client = MlflowClient()
self._ensure_registered_model(client, self.registered_model_name)
@@ -129,7 +171,7 @@ class MlflowTracker:
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
mlflow.set_tag("qc_cli.registered_model_version", version_number)
return version_number
return FinalizeResult(registered_model_version=version_number, warnings=tuple(warnings))
def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None}
@@ -146,6 +188,29 @@ class MlflowTracker:
if metrics:
mlflow.log_metrics(metrics)
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> list[str]:
    """Import metrics bundled in the model artifact via the ``mlflow`` API.

    Downloads ``model_artifacts`` into a temporary directory, reads the
    metrics file out of the archive, then logs the per-step metrics, the
    summary metrics and the raw document.

    Args:
        model_artifacts: Location of the model archive handed to the S3 helper.
        region: AWS region used for the download.
        profile: AWS profile used for the download.

    Returns:
        A list of warning messages. It is empty on success and otherwise
        holds one entry explaining why metrics were not imported. Any
        ``Exception`` raised along the way is reported this way instead of
        propagating.
    """
    try:
        with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as workdir:
            local_target = os.path.join(workdir, "model.tar.gz")
            downloaded = s3.download_file(region, profile, model_artifacts, local_target)

            payload = read_training_metrics_from_tar(downloaded)
            if payload is None:
                return [f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact."]

            parsed = parse_training_metrics(payload)
            for entry in parsed.steps:
                if not entry.metrics:
                    # Nothing to log for a step without metrics.
                    continue
                mlflow.log_metrics(entry.metrics, step=entry.step)
            if parsed.summary:
                mlflow.log_metrics(parsed.summary)
            mlflow.log_dict(parsed.raw, METRICS_ARTIFACT_NAME)
    except Exception as exc:
        # Best effort: a metrics import failure must not fail finalization.
        return [f"Could not import training metrics: {exc}"]
    return []
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try:
client.get_registered_model(name)