include-metrics-from-training (#6)

Reviewed-on: #6
This commit was merged in pull request #6.
This commit is contained in:
2026-06-12 18:23:25 +00:00
parent 522ddc74e2
commit a1ffbb77c5
13 changed files with 785 additions and 116 deletions

View File

@@ -105,7 +105,11 @@ mlflow:
tracking_server_name: your-tracking-server-name tracking_server_name: your-tracking-server-name
``` ```
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release. When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Metric upload through
`train start --upload-metrics` or `mlflow upload-metrics` finalizes that run and registers completed model artifacts
as experiment model versions using the `experiment-latest` MLflow alias. `train status` reads SageMaker status only.
An experiment version is an immutable trained-source artifact; it records that training produced a model, not that
the model is better than earlier versions or ready for release.
To open the managed SageMaker MLflow UI, request a fresh presigned URL: To open the managed SageMaker MLflow UI, request a fresh presigned URL:
@@ -128,9 +132,14 @@ qc-cli init --force Overwrite an existing config file
### `mlflow` ### `mlflow`
``` ```
qc-cli mlflow open Open a presigned MLflow UI URL in a browser qc-cli mlflow open Open a presigned MLflow UI URL
qc-cli mlflow upload-metrics [job-name] Upload completed training metrics
``` ```
`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run,
imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`.
Use `--force` to upload the metrics again.
### `infra` ### `infra`
``` ```
@@ -163,6 +172,7 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
``` ```
qc-cli train start Submit a SageMaker training job qc-cli train start Submit a SageMaker training job
qc-cli train start --upload-metrics Submit, wait, and upload metrics
qc-cli train status [job-name] Show job status; defaults to the last submitted job qc-cli train status [job-name] Show job status; defaults to the last submitted job
qc-cli train list List recent training jobs qc-cli train list List recent training jobs
qc-cli train list --limit 3 Show a custom number of recent jobs qc-cli train list --limit 3 Show a custom number of recent jobs
@@ -170,6 +180,8 @@ qc-cli train list --limit 3 Show a custom number of recent jobs
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container. `train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
The expected output artifact is SageMakers `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. The expected output artifact is SageMakers `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
### `ai-hub` ### `ai-hub`
@@ -216,13 +228,34 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
Current behavior: Current behavior:
1. `qc-cli train start` submits a SageMaker training job. 1. `qc-cli train start` submits a SageMaker training job.
2. `qc-cli train status` finalizes the MLflow run after the job reaches a terminal state. 2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow.
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: 3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with:
- `qc_cli.stage=experiment` - `qc_cli.stage=experiment`
- `qc_cli.artifact_kind=trained_source` - `qc_cli.artifact_kind=trained_source`
- `qc_cli.source=sagemaker` - `qc_cli.source=sagemaker`
4. The MLflow alias `experiment-latest` points at the most recently registered experiment version. 6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. 7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the
explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow
step and stores the JSON as a run artifact:
```json
{
"schema_version": 1,
"steps": [
{"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}}
],
"summary": {"summary.best_epoch": 0}
}
```
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and
strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues
model registration without per-epoch history. A malformed metrics artifact still fails the upload command without
affecting the trained model or model registration.
Future release aliases such as `v1` or `production` can point at a selected deployable artifact. Future release aliases such as `v1` or `production` can point at a selected deployable artifact.

View File

@@ -153,6 +153,20 @@ Or pass the job name explicitly:
qc-cli train status qc-cli-YYYYMMDD-HHMMSS qc-cli train status qc-cli-YYYYMMDD-HHMMSS
``` ```
To submit the job, wait for completion, and automatically import metrics and register the model, run:
```bash
qc-cli train start --upload-metrics
```
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
The metrics can be also submitted using:
```bash
qc-cli mlflow upload-metrics
```
## SageMaker Outputs ## SageMaker Outputs
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`. When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
@@ -163,10 +177,15 @@ This example writes:
best.pt best.pt
model.onnx model.onnx
metrics.json metrics.json
training_metrics.json
``` ```
The archive is stored under the configured `s3.model_prefix`. The archive is stored under the configured `s3.model_prefix`.
The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation
losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall
are more meaningful than classification accuracy when assessing model quality.
## 6. Configure Qualcomm AI Hub ## 6. Configure Qualcomm AI Hub
Authenticate with Qualcomm AI Hub: Authenticate with Qualcomm AI Hub:

View File

@@ -12,6 +12,7 @@ from typing import Any
import yaml import yaml
from sanitize_onnx import sanitize_onnx from sanitize_onnx import sanitize_onnx
from training_metrics import write_training_metrics
from ultralytics import YOLO # type: ignore[reportMissingImports] from ultralytics import YOLO # type: ignore[reportMissingImports]
@@ -101,6 +102,7 @@ def main() -> None:
if not trained_weights.exists(): if not trained_weights.exists():
raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}") raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json")
copy_if_exists(trained_weights, model_dir / "best.pt") copy_if_exists(trained_weights, model_dir / "best.pt")
trained_model = YOLO(str(trained_weights)) trained_model = YOLO(str(trained_weights))
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz)) onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))

View File

@@ -0,0 +1,82 @@
import csv
import json
import math
import re
from pathlib import Path
from typing import Any
METRIC_NAMES = {
"metrics/precision(B)": "val.precision",
"metrics/recall(B)": "val.recall",
"metrics/mAP50(B)": "val.map50",
"metrics/mAP50-95(B)": "val.map50_95",
"train/box_loss": "train.box_loss",
"train/cls_loss": "train.cls_loss",
"train/dfl_loss": "train.dfl_loss",
"val/box_loss": "val.box_loss",
"val/cls_loss": "val.cls_loss",
"val/dfl_loss": "val.dfl_loss",
"time": "train.elapsed_seconds",
}
def write_training_metrics(results_csv: Path, destination: Path) -> None:
steps = _read_metric_steps(results_csv)
summary = _build_summary(steps)
payload = {
"schema_version": 1,
"steps": steps,
"summary": summary,
}
destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
print(f"Saved {destination}")
def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]:
if not results_csv.is_file():
raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}")
steps: list[dict[str, Any]] = []
with results_csv.open(newline="", encoding="utf-8") as csv_file:
for row_index, raw_row in enumerate(csv.DictReader(csv_file)):
row = {str(key).strip(): value for key, value in raw_row.items()}
raw_epoch = row.pop("epoch", row_index)
step = int(float(raw_epoch))
metrics: dict[str, float] = {}
for source_name, raw_value in row.items():
if raw_value is None or not raw_value.strip():
continue
try:
value = float(raw_value)
except ValueError:
continue
if math.isfinite(value):
metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value
steps.append({"step": step, "metrics": metrics})
return steps
def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]:
if not steps:
return {}
summary: dict[str, float] = {}
final_step = steps[-1]
summary["summary.final_epoch"] = float(final_step["step"])
for name, value in final_step["metrics"].items():
summary[f"summary.final.{name}"] = value
scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]]
if scored_steps:
best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"])
summary["summary.best_epoch"] = float(best_step["step"])
summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"]
if "val.map50" in best_step["metrics"]:
summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"]
return summary
def _normalize_metric_name(name: str) -> str:
normalized = name.replace("/", ".")
normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized)
return normalized.strip("._") or "unnamed"

View File

@@ -1,3 +1,6 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, cast from typing import Any, cast
import boto3 import boto3
@@ -34,3 +37,38 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) -
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker") client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name) response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
return str(response["AuthorizedUrl"]) return str(response["AuthorizedUrl"])
@contextmanager
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
if credentials is None:
raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
frozen_credentials = credentials.get_frozen_credentials()
if not frozen_credentials.access_key or not frozen_credentials.secret_key:
raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
env_updates = {
"AWS_PROFILE": profile,
"AWS_DEFAULT_REGION": region,
"AWS_REGION": region,
"AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
"AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
}
if frozen_credentials.token:
env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
previous_env = {key: os.environ.get(key) for key in restore_keys}
try:
os.environ.update(env_updates)
if not frozen_credentials.token:
os.environ.pop("AWS_SESSION_TOKEN", None)
yield
finally:
for key, value in previous_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value

View File

@@ -0,0 +1 @@
"""Cloud provider adapters."""

77
src/cloud/mlflow.py Normal file
View File

@@ -0,0 +1,77 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Any, Protocol
from src.aws import mlflow as aws_mlflow
from src.config import Config
class MlflowTrackingBackend(Protocol):
@property
def provider_name(self) -> str: ...
@property
def profile(self) -> str: ...
@property
def region(self) -> str: ...
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
def auth_env(self) -> AbstractContextManager[None]: ...
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
profile: str
region: str
provider_name: str = "aws"
def get_tracking_uri(self, tracking_server_name: str) -> str:
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
def auth_env(self) -> AbstractContextManager[None]:
return aws_mlflow.tracking_auth_env(self.profile, self.region)
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
return {
"provider.name": self.provider_name,
"provider.region": region,
"provider.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job.job_name}
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
return {
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job_status.name}
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)

View File

@@ -2,8 +2,11 @@ import webbrowser
import typer import typer
from src import state as state_ops
from src.aws import mlflow as aws_mlflow from src.aws import mlflow as aws_mlflow
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import MlflowMode
from src.tracking.upload import upload_training_metrics
app = typer.Typer(help="Manage MLflow tracking server access") app = typer.Typer(help="Manage MLflow tracking server access")
@@ -39,3 +42,54 @@ def open_mlflow(config: str = CONFIG_OPT) -> None:
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.") CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
else: else:
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]") CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
@app.command(name="upload-metrics")
def upload_metrics(
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
config: str = CONFIG_OPT,
) -> None:
"""Upload a completed training job's metric history to MLflow."""
cfg = load_cfg(config)
if cfg.mlflow.mode is MlflowMode.disabled:
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
raise typer.Exit(1)
st = state_ops.store(config)
if not job_name:
job_name = st.get_last_training_job()
if not job_name:
CONSOLE.print(
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
)
raise typer.Exit(1)
if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
return
try:
result = upload_training_metrics(
job_name=job_name,
config_path=config,
cfg=cfg,
force=force,
)
except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
raise typer.Exit(1)
if result.metrics_history_uploaded:
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
else:
CONSOLE.print(
f"[yellow]No training_metrics.json was found in the SageMaker model artifact for "
f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]"
)
CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
if result.registered_model_version:
CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
"([cyan]experiment-latest[/cyan])"
)

View File

@@ -1,3 +1,4 @@
import time
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -11,6 +12,7 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode from src.config import Config, MlflowMode
from src.infra.state import read_infra_state from src.infra.state import read_infra_state
from src.tracking.mlflow import MlflowTracker from src.tracking.mlflow import MlflowTracker
from src.tracking.upload import upload_training_metrics
app = typer.Typer(help="Manage SageMaker training jobs") app = typer.Typer(help="Manage SageMaker training jobs")
@@ -21,6 +23,8 @@ _STATUS_COLOR = {
"Stopping": "yellow", "Stopping": "yellow",
"Stopped": "dim", "Stopped": "dim",
} }
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
DEFAULT_POLL_INTERVAL_SECONDS = 30
def _tracker(cfg): def _tracker(cfg):
@@ -48,11 +52,100 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.") raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
color = _STATUS_COLOR.get(status.status, "white")
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
if status.created:
CONSOLE.print(f"Created: {status.created}")
if status.model_artifacts:
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
if status.failure_reason:
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
def _wait_and_upload_metrics(
*,
job_name: str,
poll_interval: int,
config_path: str,
cfg: Config,
) -> None:
st = state_ops.store(config_path)
previous_status: str | None = None
try:
while True:
training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
if training_status.status != previous_status:
color = _STATUS_COLOR.get(training_status.status, "white")
CONSOLE.print(
f"Job [cyan]{training_status.name}[/cyan]: "
f"[{color}]{training_status.status}[/{color}]"
)
previous_status = training_status.status
if training_status.status in _TERMINAL_STATUSES:
_print_training_status(training_status)
if training_status.status != "Completed":
raise typer.Exit(1)
try:
result = upload_training_metrics(
job_name=job_name,
config_path=config_path,
cfg=cfg,
)
if result.metrics_history_uploaded:
CONSOLE.print(
f"[green]✓[/green] Uploaded training metrics to MLflow run "
f"[cyan]{result.run_id}[/cyan]."
)
else:
CONSOLE.print(
"[yellow]No training_metrics.json was found in the SageMaker model artifact. "
"Uploaded SageMaker final metrics only.[/yellow]"
)
if result.registered_model_version:
CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
"([cyan]experiment-latest[/cyan])"
)
except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
CONSOLE.print(
f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
)
raise typer.Exit(1)
job_state = st.get_training_job(job_name)
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
return
time.sleep(poll_interval)
except KeyboardInterrupt:
CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
raise typer.Exit(130)
@app.command() @app.command()
def start(config: str = CONFIG_OPT) -> None: def start(
upload_metrics: bool = typer.Option(
False,
"--upload-metrics",
help="Wait for completion, then upload training metrics to MLflow",
),
poll_interval: int = typer.Option(
DEFAULT_POLL_INTERVAL_SECONDS,
"--poll-interval",
min=1,
help="Seconds between status checks when --upload-metrics is used",
),
config: str = CONFIG_OPT,
) -> None:
"""Submit a SageMaker training job.""" """Submit a SageMaker training job."""
cfg = load_cfg(config) cfg = load_cfg(config)
if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
raise typer.Exit(1)
if not cfg.sagemaker.training.image_uri: if not cfg.sagemaker.training.image_uri:
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]") CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
CONSOLE.print( CONSOLE.print(
@@ -89,12 +182,20 @@ def start(config: str = CONFIG_OPT) -> None:
st = state_ops.store(config) st = state_ops.store(config)
st.set_last_training_job(job_name) st.set_last_training_job(job_name)
try:
run_id = tracker.start_training_run( run_id = tracker.start_training_run(
training_job, training_job,
region=cfg.aws.region, region=cfg.aws.region,
profile=cfg.aws.profile, profile=cfg.aws.profile,
role_arn=role_arn, role_arn=role_arn,
) )
except Exception as e:
run_id = None
CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
CONSOLE.print(
"The SageMaker job is still running. Upload metrics after completion with "
f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
)
if run_id: if run_id:
st.update_training_job(job_name, mlflow_run_id=run_id) st.update_training_job(job_name, mlflow_run_id=run_id)
@@ -102,6 +203,14 @@ def start(config: str = CONFIG_OPT) -> None:
if run_id: if run_id:
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
if upload_metrics:
_wait_and_upload_metrics(
job_name=job_name,
poll_interval=poll_interval,
config_path=config,
cfg=cfg,
)
else:
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
@@ -123,35 +232,7 @@ def status(
raise typer.Exit(1) raise typer.Exit(1)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
color = _STATUS_COLOR.get(status.status, "white") _print_training_status(status)
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
if status.created:
CONSOLE.print(f"Created: {status.created}")
if status.model_artifacts:
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
if status.failure_reason:
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
job_state = st.get_training_job(job_name)
run_id = job_state.get("mlflow_run_id")
already_registered = job_state.get("registered_model_version")
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
tracker = _tracker(cfg)
version = tracker.finalize_training_run(
run_id=str(run_id),
training_job_status=status,
)
updates = {"mlflow_finalized_status": status.status}
if version:
updates["registered_model_version"] = version
st.update_training_job(job_name, **updates)
if version:
st.set_latest_experiment_model_version(version)
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@app.command(name="list") @app.command(name="list")

View File

@@ -1,3 +1,3 @@
from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
__all__ = ["MlflowTracker", "NoopTracker", "Tracker"] __all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]

93
src/tracking/metrics.py Normal file
View File

@@ -0,0 +1,93 @@
import json
import math
import tarfile
from dataclasses import dataclass
from pathlib import PurePosixPath
from typing import Any
METRICS_ARTIFACT_NAME = "training_metrics.json"
METRICS_SCHEMA_VERSION = 1
MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
@dataclass(frozen=True)
class MetricStep:
step: int
metrics: dict[str, float]
@dataclass(frozen=True)
class TrainingMetrics:
steps: list[MetricStep]
summary: dict[str, float]
raw: dict[str, Any]
def parse_training_metrics(data: bytes) -> TrainingMetrics:
try:
value = json.loads(data)
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc
if not isinstance(value, dict):
raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")
if value.get("schema_version") != METRICS_SCHEMA_VERSION:
raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}")
raw_steps = value.get("steps")
if not isinstance(raw_steps, list):
raise ValueError("training metrics 'steps' must be a list")
steps: list[MetricStep] = []
previous_step: int | None = None
for index, raw_step in enumerate(raw_steps):
if not isinstance(raw_step, dict):
raise ValueError(f"training metrics step {index} must be an object")
step = raw_step.get("step")
if isinstance(step, bool) or not isinstance(step, int) or step < 0:
raise ValueError(f"training metrics step {index} has an invalid 'step'")
if previous_step is not None and step <= previous_step:
raise ValueError("training metrics steps must be unique and strictly increasing")
metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}")
steps.append(MetricStep(step=step, metrics=metrics))
previous_step = step
summary = _numeric_metrics(value.get("summary", {}), "training metrics summary")
return TrainingMetrics(steps=steps, summary=summary, raw=value)
def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
with tarfile.open(archive_path, mode="r:*") as archive:
matches = [
member
for member in archive.getmembers()
if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME
]
if not matches:
return None
if len(matches) > 1:
raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")
if matches[0].size > MAX_METRICS_ARTIFACT_BYTES:
raise ValueError(
f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
)
extracted = archive.extractfile(matches[0])
if extracted is None:
raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
return extracted.read()
def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
if not isinstance(value, dict):
raise ValueError(f"{context} 'metrics' must be an object")
metrics: dict[str, float] = {}
for raw_name, raw_value in value.items():
if not isinstance(raw_name, str) or not raw_name:
raise ValueError(f"{context} contains an invalid metric name")
if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
raise ValueError(f"{context} metric '{raw_name}' must be numeric")
metric_value = float(raw_value)
if not math.isfinite(metric_value):
raise ValueError(f"{context} metric '{raw_name}' must be finite")
metrics[raw_name] = metric_value
return metrics

View File

@@ -1,18 +1,46 @@
import os import os
import tempfile
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Protocol from typing import Any, Protocol
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from src.aws import mlflow as aws_mlflow from src.aws import s3
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
from src.config import Config, MlflowMode from src.config import Config, MlflowMode
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
@dataclass(frozen=True)
class FinalizeResult:
registered_model_version: str | None = None
warnings: tuple[str, ...] = ()
class Tracker(Protocol): class Tracker(Protocol):
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ... def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ...
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: ... def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult: ...
def ensure_training_run(self, job_name: str) -> str: ...
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> bool: ...
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -20,8 +48,29 @@ class NoopTracker:
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
return None return None
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: def finalize_training_run(
return None self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult:
return FinalizeResult()
def ensure_training_run(self, job_name: str) -> str:
raise RuntimeError("MLflow is disabled.")
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> bool:
raise RuntimeError("MLflow is disabled.")
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -30,6 +79,7 @@ class MlflowTracker:
experiment_name: str experiment_name: str
registered_model_name: str registered_model_name: str
register_trained_models: bool register_trained_models: bool
tracking_backend: MlflowTrackingBackend
@classmethod @classmethod
def from_config(cls, cfg: Config) -> Tracker: def from_config(cls, cfg: Config) -> Tracker:
@@ -42,11 +92,10 @@ class MlflowTracker:
if not tracking_server_name: if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.") raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_uri = aws_mlflow.get_tracking_server_arn( tracking_backend = mlflow_tracking_backend_from_config(cfg)
cfg.aws.region,
cfg.aws.profile, tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
tracking_server_name, with tracking_backend.auth_env():
)
mlflow.set_tracking_uri(tracking_uri) mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name) mlflow.set_experiment(cfg.mlflow.experiment_name)
@@ -55,62 +104,57 @@ class MlflowTracker:
experiment_name=cfg.mlflow.experiment_name, experiment_name=cfg.mlflow.experiment_name,
registered_model_name=cfg.mlflow.registered_model_name, registered_model_name=cfg.mlflow.registered_model_name,
register_trained_models=cfg.mlflow.register_trained_models, register_trained_models=cfg.mlflow.register_trained_models,
tracking_backend=tracking_backend,
) )
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
run = mlflow.start_run(run_name=training_job.job_name) with self.tracking_backend.auth_env():
with mlflow.start_run(run_name=training_job.job_name) as run:
run_id = str(run.info.run_id) run_id = str(run.info.run_id)
self._log_params(
params = { self.tracking_backend.training_run_params(
"aws.region": region, training_job,
"aws.profile": profile, region=region,
"sagemaker.role_arn": role_arn, profile=profile,
"sagemaker.job_name": training_job.job_name, role_arn=role_arn,
"sagemaker.training_image": training_job.image_uri, )
"sagemaker.instance_type": training_job.instance_type, )
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags( mlflow.set_tags(
{ {
"qc_cli.stage": "experiment", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source", "qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker", "qc_cli.source": self.tracking_backend.provider_name,
"qc_cli.command": "train start", "qc_cli.command": "train start",
"sagemaker.job_name": training_job.job_name, **self.tracking_backend.training_run_tags(training_job),
} }
) )
mlflow.end_run()
return run_id return run_id
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult:
if not run_id: if not run_id:
return None return FinalizeResult()
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id): with mlflow.start_run(run_id=run_id):
self._log_params( self._log_params(self.tracking_backend.training_status_params(training_job_status))
{
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
)
self._log_final_metrics(training_job_status.raw) self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", "train status") mlflow.set_tag("qc_cli.command", command)
if training_job_status.status != "Completed" or not training_job_status.model_artifacts: if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
return None return FinalizeResult()
if not self.register_trained_models: if not self.register_trained_models:
return None return FinalizeResult()
client = MlflowClient() client = MlflowClient()
self._ensure_registered_model(client, self.registered_model_name) self._ensure_registered_model(client, self.registered_model_name)
@@ -121,15 +165,65 @@ class MlflowTracker:
tags={ tags={
"qc_cli.stage": "experiment", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source", "qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker", "qc_cli.source": self.tracking_backend.provider_name,
"sagemaker.job_name": training_job_status.name, **self.tracking_backend.model_version_tags(training_job_status),
}, },
) )
version_number = str(version.version) version_number = str(version.version)
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
mlflow.set_tag("qc_cli.registered_model_version", version_number) mlflow.set_tag("qc_cli.registered_model_version", version_number)
return version_number return FinalizeResult(registered_model_version=version_number)
def ensure_training_run(self, job_name: str) -> str:
with self.tracking_backend.auth_env():
client = MlflowClient()
experiment = client.get_experiment_by_name(self.experiment_name)
if experiment is None:
experiment_id = mlflow.create_experiment(self.experiment_name)
else:
experiment_id = experiment.experiment_id
for run in client.search_runs([experiment_id], max_results=1000):
if run.data.tags.get("sagemaker.job_name") == job_name:
return str(run.info.run_id)
run = client.create_run(
experiment_id,
run_name=job_name,
tags={
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name,
"qc_cli.command": "mlflow upload-metrics",
"sagemaker.job_name": job_name,
},
)
return str(run.info.run_id)
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> bool:
if not training_job_status.model_artifacts:
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id):
self._log_params(self.tracking_backend.training_status_params(training_job_status))
self._log_final_metrics(training_job_status.raw)
history_uploaded = self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
return history_uploaded
def _log_params(self, params: dict[str, Any]) -> None: def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None} cleaned = {key: str(value) for key, value in params.items() if value is not None}
@@ -146,6 +240,26 @@ class MlflowTracker:
if metrics: if metrics:
mlflow.log_metrics(metrics) mlflow.log_metrics(metrics)
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
archive_path = s3.download_file(
region,
profile,
model_artifacts,
os.path.join(temp_dir, "model.tar.gz"),
)
metrics_data = read_training_metrics_from_tar(archive_path)
if metrics_data is None:
return False
metrics = parse_training_metrics(metrics_data)
for metric_step in metrics.steps:
if metric_step.metrics:
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
if metrics.summary:
mlflow.log_metrics(metrics.summary)
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
return True
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None: def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try: try:
client.get_registered_model(name) client.get_registered_model(name)

75
src/tracking/upload.py Normal file
View File

@@ -0,0 +1,75 @@
from dataclasses import dataclass
from src import state as state_ops
from src.aws import sagemaker as sm_ops
from src.config import Config, MlflowMode
from src.tracking.mlflow import MlflowTracker
@dataclass(frozen=True)
class MetricsUploadResult:
run_id: str
registered_model_version: str | None = None
metrics_history_uploaded: bool = True
def upload_training_metrics(
*,
job_name: str,
config_path: str,
cfg: Config,
force: bool = False,
) -> MetricsUploadResult:
if cfg.mlflow.mode is MlflowMode.disabled:
raise RuntimeError("MLflow is disabled in config.yaml.")
st = state_ops.store(config_path)
job_state = st.get_training_job(job_name)
if job_state.get("mlflow_metrics_uploaded") and not force:
return MetricsUploadResult(
run_id=str(job_state.get("mlflow_run_id") or ""),
registered_model_version=(
str(job_state["registered_model_version"])
if job_state.get("registered_model_version")
else None
),
metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)),
)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
if status.status != "Completed":
raise RuntimeError(
f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
)
tracker = MlflowTracker.from_config(cfg)
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
st.update_training_job(job_name, mlflow_run_id=run_id)
metrics_history_uploaded = tracker.upload_training_metrics(
run_id=run_id,
training_job_status=status,
region=cfg.aws.region,
profile=cfg.aws.profile,
)
finalized = tracker.finalize_training_run(
run_id=run_id,
training_job_status=status,
region=cfg.aws.region,
profile=cfg.aws.profile,
command="mlflow upload-metrics",
)
updates = {
"mlflow_metrics_uploaded": True,
"mlflow_metrics_history_uploaded": metrics_history_uploaded,
"mlflow_finalized_status": status.status,
}
if finalized.registered_model_version:
updates["registered_model_version"] = finalized.registered_model_version
st.update_training_job(job_name, **updates)
if finalized.registered_model_version:
st.set_latest_experiment_model_version(finalized.registered_model_version)
return MetricsUploadResult(
run_id=run_id,
registered_model_version=finalized.registered_model_version,
metrics_history_uploaded=metrics_history_uploaded,
)