From 3846c5d88d78759ba97b7257f80e4a39666a6ea6 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 5 Jun 2026 15:52:55 -0400 Subject: [PATCH 1/6] add aws context for MLFlow --- src/aws/mlflow.py | 38 ++++++++++++ src/cloud/__init__.py | 1 + src/cloud/mlflow.py | 77 ++++++++++++++++++++++++ src/tracking/mlflow.py | 131 +++++++++++++++++++---------------------- 4 files changed, 176 insertions(+), 71 deletions(-) create mode 100644 src/cloud/__init__.py create mode 100644 src/cloud/mlflow.py diff --git a/src/aws/mlflow.py b/src/aws/mlflow.py index 35433bb..344e70f 100644 --- a/src/aws/mlflow.py +++ b/src/aws/mlflow.py @@ -1,3 +1,6 @@ +import os +from collections.abc import Generator +from contextlib import contextmanager from typing import Any, cast import boto3 @@ -34,3 +37,38 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) - client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker") response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name) return str(response["AuthorizedUrl"]) + + +@contextmanager +def tracking_auth_env(profile: str, region: str) -> Generator[None]: + credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials() + if credentials is None: + raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.") + + frozen_credentials = credentials.get_frozen_credentials() + if not frozen_credentials.access_key or not frozen_credentials.secret_key: + raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.") + + env_updates = { + "AWS_PROFILE": profile, + "AWS_DEFAULT_REGION": region, + "AWS_REGION": region, + "AWS_ACCESS_KEY_ID": frozen_credentials.access_key, + "AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key, + } + if frozen_credentials.token: + env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token + + restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"} + previous_env = {key: os.environ.get(key) for key in restore_keys} + try: + os.environ.update(env_updates) + if not frozen_credentials.token: + os.environ.pop("AWS_SESSION_TOKEN", None) + yield + finally: + for key, value in previous_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value diff --git a/src/cloud/__init__.py b/src/cloud/__init__.py new file mode 100644 index 0000000..75bf900 --- /dev/null +++ b/src/cloud/__init__.py @@ -0,0 +1 @@ +"""Cloud provider adapters.""" diff --git a/src/cloud/mlflow.py b/src/cloud/mlflow.py new file mode 100644 index 0000000..08ae203 --- /dev/null +++ b/src/cloud/mlflow.py @@ -0,0 +1,77 @@ +from contextlib import AbstractContextManager +from dataclasses import dataclass +from typing import Any, Protocol + +from src.aws import mlflow as aws_mlflow +from src.config import Config + + +class MlflowTrackingBackend(Protocol): + @property + def provider_name(self) -> str: ... + + @property + def profile(self) -> str: ... + + @property + def region(self) -> str: ... + + def get_tracking_uri(self, tracking_server_name: str) -> str: ... + + def auth_env(self) -> AbstractContextManager[None]: ... + + def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ... + + def training_run_tags(self, training_job: Any) -> dict[str, Any]: ... + + def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ... + + def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ... + + +@dataclass(frozen=True) +class AwsMlflowTrackingBackend: + profile: str + region: str + provider_name: str = "aws" + + def get_tracking_uri(self, tracking_server_name: str) -> str: + return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name) + + def auth_env(self) -> AbstractContextManager[None]: + return aws_mlflow.tracking_auth_env(self.profile, self.region) + + def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: + return { + "provider.name": self.provider_name, + "provider.region": region, + "provider.profile": profile, + "sagemaker.role_arn": role_arn, + "sagemaker.job_name": training_job.job_name, + "sagemaker.training_image": training_job.image_uri, + "sagemaker.instance_type": training_job.instance_type, + "sagemaker.instance_count": training_job.instance_count, + "sagemaker.s3_train_uri": training_job.s3_train_uri, + "sagemaker.s3_output_path": training_job.s3_output_path, + "sagemaker.entry_point": training_job.entry_point, + "sagemaker.source_dir": training_job.source_dir, + } + + def training_run_tags(self, training_job: Any) -> dict[str, Any]: + return {"sagemaker.job_name": training_job.job_name} + + def training_status_params(self, training_job_status: Any) -> dict[str, Any]: + return { + "sagemaker.training_status": training_job_status.status, + "sagemaker.created_at": training_job_status.created, + "sagemaker.modified_at": training_job_status.modified, + "sagemaker.model_artifacts": training_job_status.model_artifacts, + "sagemaker.failure_reason": training_job_status.failure_reason, + } + + def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: + return {"sagemaker.job_name": training_job_status.name} + + +def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend: + return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region) diff --git a/src/tracking/mlflow.py b/src/tracking/mlflow.py index 0e8f5d0..2483870 100644 --- a/src/tracking/mlflow.py +++ b/src/tracking/mlflow.py @@ -5,7 +5,7 @@ from typing import Any, Protocol import mlflow from mlflow.tracking import MlflowClient -from src.aws import mlflow as aws_mlflow +from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config from src.config import Config, MlflowMode @@ -30,6 +30,7 @@ class MlflowTracker: experiment_name: str registered_model_name: str register_trained_models: bool + tracking_backend: MlflowTrackingBackend @classmethod def from_config(cls, cfg: Config) -> Tracker: @@ -42,94 +43,82 @@ class MlflowTracker: if not tracking_server_name: raise RuntimeError("MLflow tracking server name could not be resolved.") - tracking_uri = aws_mlflow.get_tracking_server_arn( - cfg.aws.region, - cfg.aws.profile, - tracking_server_name, - ) - mlflow.set_tracking_uri(tracking_uri) - mlflow.set_experiment(cfg.mlflow.experiment_name) + tracking_backend = mlflow_tracking_backend_from_config(cfg) + + tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name) + with tracking_backend.auth_env(): + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment(cfg.mlflow.experiment_name) return cls( tracking_uri=tracking_uri, experiment_name=cfg.mlflow.experiment_name, registered_model_name=cfg.mlflow.registered_model_name, register_trained_models=cfg.mlflow.register_trained_models, + tracking_backend=tracking_backend, ) def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: - run = mlflow.start_run(run_name=training_job.job_name) - run_id = str(run.info.run_id) + with self.tracking_backend.auth_env(): + run = mlflow.start_run(run_name=training_job.job_name) + run_id = str(run.info.run_id) - params = { - "aws.region": region, - "aws.profile": profile, - "sagemaker.role_arn": role_arn, - "sagemaker.job_name": training_job.job_name, - "sagemaker.training_image": training_job.image_uri, - "sagemaker.instance_type": training_job.instance_type, - "sagemaker.instance_count": training_job.instance_count, - "sagemaker.s3_train_uri": training_job.s3_train_uri, - "sagemaker.s3_output_path": training_job.s3_output_path, - "sagemaker.entry_point": training_job.entry_point, - "sagemaker.source_dir": training_job.source_dir, - } - self._log_params(params) - self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) - mlflow.set_tags( - { - "qc_cli.stage": "experiment", - "qc_cli.artifact_kind": "trained_source", - "qc_cli.source": "sagemaker", - "qc_cli.command": "train start", - "sagemaker.job_name": training_job.job_name, - } - ) - mlflow.end_run() - return run_id + self._log_params( + self.tracking_backend.training_run_params( + training_job, + region=region, + profile=profile, + role_arn=role_arn, + ) + ) + self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) + mlflow.set_tags( + { + "qc_cli.stage": "experiment", + "qc_cli.artifact_kind": "trained_source", + "qc_cli.source": self.tracking_backend.provider_name, + "qc_cli.command": "train start", + **self.tracking_backend.training_run_tags(training_job), + } + ) + mlflow.end_run() + return run_id def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: if not run_id: return None - with mlflow.start_run(run_id=run_id): - self._log_params( - { - "sagemaker.training_status": training_job_status.status, - "sagemaker.created_at": training_job_status.created, - "sagemaker.modified_at": training_job_status.modified, - "sagemaker.model_artifacts": training_job_status.model_artifacts, - "sagemaker.failure_reason": training_job_status.failure_reason, - } - ) - self._log_final_metrics(training_job_status.raw) - mlflow.set_tag("qc_cli.command", "train status") + with self.tracking_backend.auth_env(): + with mlflow.start_run(run_id=run_id): + self._log_params(self.tracking_backend.training_status_params(training_job_status)) + self._log_final_metrics(training_job_status.raw) + mlflow.set_tag("qc_cli.command", "train status") - if training_job_status.status != "Completed" or not training_job_status.model_artifacts: - mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) - return None + if training_job_status.status != "Completed" or not training_job_status.model_artifacts: + mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) + return None - if not self.register_trained_models: - return None + if not self.register_trained_models: + return None - client = MlflowClient() - self._ensure_registered_model(client, self.registered_model_name) - version = client.create_model_version( - name=self.registered_model_name, - source=training_job_status.model_artifacts, - run_id=run_id, - tags={ - "qc_cli.stage": "experiment", - "qc_cli.artifact_kind": "trained_source", - "qc_cli.source": "sagemaker", - "sagemaker.job_name": training_job_status.name, - }, - ) - version_number = str(version.version) - client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) - mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) - mlflow.set_tag("qc_cli.registered_model_version", version_number) - return version_number + client = MlflowClient() + self._ensure_registered_model(client, self.registered_model_name) + version = client.create_model_version( + name=self.registered_model_name, + source=training_job_status.model_artifacts, + run_id=run_id, + tags={ + "qc_cli.stage": "experiment", + "qc_cli.artifact_kind": "trained_source", + "qc_cli.source": self.tracking_backend.provider_name, + **self.tracking_backend.model_version_tags(training_job_status), + }, + ) + version_number = str(version.version) + client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) + mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) + mlflow.set_tag("qc_cli.registered_model_version", version_number) + return version_number def _log_params(self, params: dict[str, Any]) -> None: cleaned = {key: str(value) for key, value in params.items() if value is not None} -- 2.49.1 From 2d4d3770510830ee6161df835fbda266e730bcf8 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 12 Jun 2026 11:42:26 -0400 Subject: [PATCH 2/6] WIP --- README.md | 19 ++- examples/meter-detection/README.md | 11 ++ examples/meter-detection/source/train.py | 2 + .../source/training_metrics.py | 82 +++++++++++ src/commands/train.py | 134 ++++++++++++++---- src/tracking/__init__.py | 4 +- src/tracking/metrics.py | 93 ++++++++++++ src/tracking/mlflow.py | 83 +++++++++-- 8 files changed, 390 insertions(+), 38 deletions(-) create mode 100644 examples/meter-detection/source/training_metrics.py create mode 100644 src/tracking/metrics.py diff --git a/README.md b/README.md index 3a61e88..770bedd 100644 --- a/README.md +++ b/README.md @@ -164,12 +164,15 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de ``` qc-cli train start Submit a SageMaker training job qc-cli train status [job-name] Show job status; defaults to the last submitted job +qc-cli train wait [job-name] Wait for completion and finalize MLflow tracking qc-cli train list List recent training jobs qc-cli train list --limit 3 Show a custom number of recent jobs ``` `train start` uses `s3:////` as the training channel and writes outputs under `s3:////`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container. +`train wait` checks SageMaker every 30 seconds by default. Use `--poll-interval ` to choose another positive interval. Stopping the local command does not stop the SageMaker job. + The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. ### `ai-hub` @@ -216,7 +219,7 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas Current behavior: 1. `qc-cli train start` submits a SageMaker training job. -2. `qc-cli train status` finalizes the MLflow run after the job reaches a terminal state. +2. `qc-cli train status` or `qc-cli train wait` finalizes the MLflow run after the job reaches a terminal state. `train wait` blocks and polls every 30 seconds by default. 3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: - `qc_cli.stage=experiment` - `qc_cli.artifact_kind=trained_source` @@ -224,6 +227,20 @@ Current behavior: 4. The MLflow alias `experiment-latest` points at the most recently registered experiment version. 5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. +Training scripts can include a `training_metrics.json` file in the SageMaker model directory. During finalization, the CLI logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact: + +```json +{ + "schema_version": 1, + "steps": [ + {"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}} + ], + "summary": {"summary.best_epoch": 0} +} +``` + +Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. Missing or malformed metrics produce a warning but do not block model registration. + Future release aliases such as `v1` or `production` can point at a selected deployable artifact. Example future metadata: diff --git a/examples/meter-detection/README.md b/examples/meter-detection/README.md index 35955df..a85a3a5 100644 --- a/examples/meter-detection/README.md +++ b/examples/meter-detection/README.md @@ -153,6 +153,14 @@ Or pass the job name explicitly: qc-cli train status qc-cli-YYYYMMDD-HHMMSS ``` +To wait for completion and automatically import metrics and register the model, run: + +```bash +qc-cli train wait +``` + +The default polling interval is 30 seconds. It can be changed with `--poll-interval `. + ## SageMaker Outputs When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`. @@ -163,10 +171,13 @@ This example writes: best.pt model.onnx metrics.json +training_metrics.json ``` The archive is stored under the configured `s3.model_prefix`. +During MLflow finalization, `training_metrics.json` provides per-epoch training and validation losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall are more meaningful than classification accuracy when assessing model quality. + ## 6. Configure Qualcomm AI Hub Authenticate with Qualcomm AI Hub: diff --git a/examples/meter-detection/source/train.py b/examples/meter-detection/source/train.py index 99aef51..f40b872 100644 --- a/examples/meter-detection/source/train.py +++ b/examples/meter-detection/source/train.py @@ -12,6 +12,7 @@ from typing import Any import yaml from sanitize_onnx import sanitize_onnx +from training_metrics import write_training_metrics from ultralytics import YOLO # type: ignore[reportMissingImports] @@ -101,6 +102,7 @@ def main() -> None: if not trained_weights.exists(): raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}") + write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json") copy_if_exists(trained_weights, model_dir / "best.pt") trained_model = YOLO(str(trained_weights)) onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz)) diff --git a/examples/meter-detection/source/training_metrics.py b/examples/meter-detection/source/training_metrics.py new file mode 100644 index 0000000..004c557 --- /dev/null +++ b/examples/meter-detection/source/training_metrics.py @@ -0,0 +1,82 @@ +import csv +import json +import math +import re +from pathlib import Path +from typing import Any + +METRIC_NAMES = { + "metrics/precision(B)": "val.precision", + "metrics/recall(B)": "val.recall", + "metrics/mAP50(B)": "val.map50", + "metrics/mAP50-95(B)": "val.map50_95", + "train/box_loss": "train.box_loss", + "train/cls_loss": "train.cls_loss", + "train/dfl_loss": "train.dfl_loss", + "val/box_loss": "val.box_loss", + "val/cls_loss": "val.cls_loss", + "val/dfl_loss": "val.dfl_loss", + "time": "train.elapsed_seconds", +} + + +def write_training_metrics(results_csv: Path, destination: Path) -> None: + steps = _read_metric_steps(results_csv) + summary = _build_summary(steps) + payload = { + "schema_version": 1, + "steps": steps, + "summary": summary, + } + destination.write_text(json.dumps(payload, indent=2), encoding="utf-8") + print(f"Saved {destination}") + + +def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]: + if not results_csv.is_file(): + raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}") + + steps: list[dict[str, Any]] = [] + with results_csv.open(newline="", encoding="utf-8") as csv_file: + for row_index, raw_row in enumerate(csv.DictReader(csv_file)): + row = {str(key).strip(): value for key, value in raw_row.items()} + raw_epoch = row.pop("epoch", row_index) + step = int(float(raw_epoch)) + metrics: dict[str, float] = {} + for source_name, raw_value in row.items(): + if raw_value is None or not raw_value.strip(): + continue + try: + value = float(raw_value) + except ValueError: + continue + if math.isfinite(value): + metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value + steps.append({"step": step, "metrics": metrics}) + return steps + + +def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]: + if not steps: + return {} + + summary: dict[str, float] = {} + final_step = steps[-1] + summary["summary.final_epoch"] = float(final_step["step"]) + for name, value in final_step["metrics"].items(): + summary[f"summary.final.{name}"] = value + + scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]] + if scored_steps: + best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"]) + summary["summary.best_epoch"] = float(best_step["step"]) + summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"] + if "val.map50" in best_step["metrics"]: + summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"] + return summary + + +def _normalize_metric_name(name: str) -> str: + normalized = name.replace("/", ".") + normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized) + return normalized.strip("._") or "unnamed" diff --git a/src/commands/train.py b/src/commands/train.py index eb96c92..5958514 100644 --- a/src/commands/train.py +++ b/src/commands/train.py @@ -1,3 +1,4 @@ +import time from datetime import datetime from pathlib import Path @@ -21,6 +22,8 @@ _STATUS_COLOR = { "Stopping": "yellow", "Stopped": "dim", } +_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"} +DEFAULT_POLL_INTERVAL_SECONDS = 30 def _tracker(cfg): @@ -48,6 +51,57 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str: raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.") +def _print_training_status(status: sm_ops.TrainingJobStatus) -> None: + color = _STATUS_COLOR.get(status.status, "white") + CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]") + CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]") + if status.created: + CONSOLE.print(f"Created: {status.created}") + if status.model_artifacts: + CONSOLE.print(f"Artifacts: {status.model_artifacts}") + if status.failure_reason: + CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]") + + +def _finalize_terminal_job( + *, + config_path: str, + cfg: Config, + status: sm_ops.TrainingJobStatus, + command: str, +) -> None: + if status.status not in _TERMINAL_STATUSES: + return + + st = state_ops.store(config_path) + job_state = st.get_training_job(status.name) + run_id = job_state.get("mlflow_run_id") + if not run_id or job_state.get("mlflow_finalized_status"): + return + + tracker = _tracker(cfg) + result = tracker.finalize_training_run( + run_id=str(run_id), + training_job_status=status, + region=cfg.aws.region, + profile=cfg.aws.profile, + command=command, + ) + updates = {"mlflow_finalized_status": status.status} + if result.registered_model_version: + updates["registered_model_version"] = result.registered_model_version + st.update_training_job(status.name, **updates) + + for warning in result.warnings: + CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]") + if result.registered_model_version: + st.set_latest_experiment_model_version(result.registered_model_version) + CONSOLE.print( + f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " + "([cyan]experiment-latest[/cyan])" + ) + + @app.command() def start(config: str = CONFIG_OPT) -> None: """Submit a SageMaker training job.""" @@ -123,37 +177,65 @@ def status( raise typer.Exit(1) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) - color = _STATUS_COLOR.get(status.status, "white") - - CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]") - CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]") - if status.created: - CONSOLE.print(f"Created: {status.created}") - if status.model_artifacts: - CONSOLE.print(f"Artifacts: {status.model_artifacts}") - if status.failure_reason: - CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]") + _print_training_status(status) + _finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status") job_state = st.get_training_job(job_name) - run_id = job_state.get("mlflow_run_id") - already_registered = job_state.get("registered_model_version") - if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}: - tracker = _tracker(cfg) - version = tracker.finalize_training_run( - run_id=str(run_id), - training_job_status=status, - ) - updates = {"mlflow_finalized_status": status.status} - if version: - updates["registered_model_version"] = version - st.update_training_job(job_name, **updates) - if version: - st.set_latest_experiment_model_version(version) - CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])") - if run_id and cfg.mlflow.mode is not MlflowMode.disabled: + if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") +@app.command() +def wait( + job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"), + poll_interval: int = typer.Option( + DEFAULT_POLL_INTERVAL_SECONDS, + "--poll-interval", + min=1, + help="Seconds between SageMaker status checks", + ), + config: str = CONFIG_OPT, +) -> None: + """Wait for a training job and finalize its MLflow run.""" + cfg = load_cfg(config) + st = state_ops.store(config) + if not job_name: + job_name = st.get_last_training_job() + if not job_name: + CONSOLE.print( + "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]" + ) + raise typer.Exit(1) + + previous_status: str | None = None + try: + while True: + training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) + if training_status.status != previous_status: + color = _STATUS_COLOR.get(training_status.status, "white") + CONSOLE.print( + f"Job [cyan]{training_status.name}[/cyan]: " + f"[{color}]{training_status.status}[/{color}]" + ) + previous_status = training_status.status + if training_status.status in _TERMINAL_STATUSES: + _print_training_status(training_status) + _finalize_terminal_job( + config_path=config, + cfg=cfg, + status=training_status, + command="train wait", + ) + job_state = st.get_training_job(job_name) + if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: + CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") + return + time.sleep(poll_interval) + except KeyboardInterrupt: + CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]") + raise typer.Exit(130) + + @app.command(name="list") def list_jobs( limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"), diff --git a/src/tracking/__init__.py b/src/tracking/__init__.py index 931f89f..6e63597 100644 --- a/src/tracking/__init__.py +++ b/src/tracking/__init__.py @@ -1,3 +1,3 @@ -from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker +from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker -__all__ = ["MlflowTracker", "NoopTracker", "Tracker"] +__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"] diff --git a/src/tracking/metrics.py b/src/tracking/metrics.py new file mode 100644 index 0000000..bb73a80 --- /dev/null +++ b/src/tracking/metrics.py @@ -0,0 +1,93 @@ +import json +import math +import tarfile +from dataclasses import dataclass +from pathlib import PurePosixPath +from typing import Any + +METRICS_ARTIFACT_NAME = "training_metrics.json" +METRICS_SCHEMA_VERSION = 1 +MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024 + + +@dataclass(frozen=True) +class MetricStep: + step: int + metrics: dict[str, float] + + +@dataclass(frozen=True) +class TrainingMetrics: + steps: list[MetricStep] + summary: dict[str, float] + raw: dict[str, Any] + + +def parse_training_metrics(data: bytes) -> TrainingMetrics: + try: + value = json.loads(data) + except (UnicodeDecodeError, json.JSONDecodeError) as exc: + raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc + if not isinstance(value, dict): + raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object") + if value.get("schema_version") != METRICS_SCHEMA_VERSION: + raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}") + + raw_steps = value.get("steps") + if not isinstance(raw_steps, list): + raise ValueError("training metrics 'steps' must be a list") + + steps: list[MetricStep] = [] + previous_step: int | None = None + for index, raw_step in enumerate(raw_steps): + if not isinstance(raw_step, dict): + raise ValueError(f"training metrics step {index} must be an object") + step = raw_step.get("step") + if isinstance(step, bool) or not isinstance(step, int) or step < 0: + raise ValueError(f"training metrics step {index} has an invalid 'step'") + if previous_step is not None and step <= previous_step: + raise ValueError("training metrics steps must be unique and strictly increasing") + metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}") + steps.append(MetricStep(step=step, metrics=metrics)) + previous_step = step + + summary = _numeric_metrics(value.get("summary", {}), "training metrics summary") + return TrainingMetrics(steps=steps, summary=summary, raw=value) + + +def read_training_metrics_from_tar(archive_path: str) -> bytes | None: + with tarfile.open(archive_path, mode="r:*") as archive: + matches = [ + member + for member in archive.getmembers() + if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME + ] + if not matches: + return None + if len(matches) > 1: + raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files") + if matches[0].size > MAX_METRICS_ARTIFACT_BYTES: + raise ValueError( + f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit" + ) + extracted = archive.extractfile(matches[0]) + if extracted is None: + raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive") + return extracted.read() + + +def _numeric_metrics(value: Any, context: str) -> dict[str, float]: + if not isinstance(value, dict): + raise ValueError(f"{context} 'metrics' must be an object") + + metrics: dict[str, float] = {} + for raw_name, raw_value in value.items(): + if not isinstance(raw_name, str) or not raw_name: + raise ValueError(f"{context} contains an invalid metric name") + if isinstance(raw_value, bool) or not isinstance(raw_value, int | float): + raise ValueError(f"{context} metric '{raw_name}' must be numeric") + metric_value = float(raw_value) + if not math.isfinite(metric_value): + raise ValueError(f"{context} metric '{raw_name}' must be finite") + metrics[raw_name] = metric_value + return metrics diff --git a/src/tracking/mlflow.py b/src/tracking/mlflow.py index 0e8f5d0..e125ae8 100644 --- a/src/tracking/mlflow.py +++ b/src/tracking/mlflow.py @@ -1,4 +1,5 @@ import os +import tempfile from dataclasses import dataclass from typing import Any, Protocol @@ -6,13 +7,29 @@ import mlflow from mlflow.tracking import MlflowClient from src.aws import mlflow as aws_mlflow +from src.aws import s3 from src.config import Config, MlflowMode +from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar + + +@dataclass(frozen=True) +class FinalizeResult: + registered_model_version: str | None = None + warnings: tuple[str, ...] = () class Tracker(Protocol): def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ... - def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: ... + def finalize_training_run( + self, + *, + run_id: str | None, + training_job_status: Any, + region: str, + profile: str, + command: str, + ) -> FinalizeResult: ... @dataclass(frozen=True) @@ -20,8 +37,16 @@ class NoopTracker: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: return None - def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: - return None + def finalize_training_run( + self, + *, + run_id: str | None, + training_job_status: Any, + region: str, + profile: str, + command: str, + ) -> FinalizeResult: + return FinalizeResult() @dataclass(frozen=True) @@ -88,10 +113,19 @@ class MlflowTracker: mlflow.end_run() return run_id - def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: + def finalize_training_run( + self, + *, + run_id: str | None, + training_job_status: Any, + region: str, + profile: str, + command: str, + ) -> FinalizeResult: if not run_id: - return None + return FinalizeResult() + warnings: list[str] = [] with mlflow.start_run(run_id=run_id): self._log_params( { @@ -103,14 +137,22 @@ class MlflowTracker: } ) self._log_final_metrics(training_job_status.raw) - mlflow.set_tag("qc_cli.command", "train status") + if training_job_status.status == "Completed" and training_job_status.model_artifacts: + warnings.extend( + self._log_training_metrics( + training_job_status.model_artifacts, + region=region, + profile=profile, + ) + ) + mlflow.set_tag("qc_cli.command", command) if training_job_status.status != "Completed" or not training_job_status.model_artifacts: mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) - return None + return FinalizeResult(warnings=tuple(warnings)) if not self.register_trained_models: - return None + return FinalizeResult(warnings=tuple(warnings)) client = MlflowClient() self._ensure_registered_model(client, self.registered_model_name) @@ -129,7 +171,7 @@ class MlflowTracker: client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) mlflow.set_tag("qc_cli.registered_model_version", version_number) - return version_number + return FinalizeResult(registered_model_version=version_number, warnings=tuple(warnings)) def _log_params(self, params: dict[str, Any]) -> None: cleaned = {key: str(value) for key, value in params.items() if value is not None} @@ -146,6 +188,29 @@ class MlflowTracker: if metrics: mlflow.log_metrics(metrics) + def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> list[str]: + try: + with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir: + archive_path = s3.download_file( + region, + profile, + model_artifacts, + os.path.join(temp_dir, "model.tar.gz"), + ) + metrics_data = read_training_metrics_from_tar(archive_path) + if metrics_data is None: + return [f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact."] + metrics = parse_training_metrics(metrics_data) + for metric_step in metrics.steps: + if metric_step.metrics: + mlflow.log_metrics(metric_step.metrics, step=metric_step.step) + if metrics.summary: + mlflow.log_metrics(metrics.summary) + mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME) + except Exception as exc: + return [f"Could not import training metrics: {exc}"] + return [] + def _ensure_registered_model(self, client: MlflowClient, name: str) -> None: try: client.get_registered_model(name) -- 2.49.1 From 53e886a535822082b979351640a2c08c5b01cf11 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 12 Jun 2026 11:57:27 -0400 Subject: [PATCH 3/6] update --- README.md | 6 +- examples/meter-detection/README.md | 4 +- src/commands/train.py | 109 +++++++++++++++-------------- 3 files changed, 61 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 770bedd..1901afc 100644 --- a/README.md +++ b/README.md @@ -163,15 +163,15 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de ``` qc-cli train start Submit a SageMaker training job +qc-cli train start --wait Submit, wait, and finalize MLflow tracking qc-cli train status [job-name] Show job status; defaults to the last submitted job -qc-cli train wait [job-name] Wait for completion and finalize MLflow tracking qc-cli train list List recent training jobs qc-cli train list --limit 3 Show a custom number of recent jobs ``` `train start` uses `s3:////` as the training channel and writes outputs under `s3:////`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container. -`train wait` checks SageMaker every 30 seconds by default. Use `--poll-interval ` to choose another positive interval. Stopping the local command does not stop the SageMaker job. +`train start --wait` checks SageMaker every 30 seconds by default. Use `--poll-interval ` to choose another positive interval. Stopping the local command does not stop the SageMaker job. The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. @@ -219,7 +219,7 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas Current behavior: 1. `qc-cli train start` submits a SageMaker training job. -2. `qc-cli train status` or `qc-cli train wait` finalizes the MLflow run after the job reaches a terminal state. `train wait` blocks and polls every 30 seconds by default. +2. `qc-cli train status` or `qc-cli train start --wait` finalizes the MLflow run after the job reaches a terminal state. `--wait` polls every 30 seconds by default. 3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: - `qc_cli.stage=experiment` - `qc_cli.artifact_kind=trained_source` diff --git a/examples/meter-detection/README.md b/examples/meter-detection/README.md index a85a3a5..e6d5c27 100644 --- a/examples/meter-detection/README.md +++ b/examples/meter-detection/README.md @@ -153,10 +153,10 @@ Or pass the job name explicitly: qc-cli train status qc-cli-YYYYMMDD-HHMMSS ``` -To wait for completion and automatically import metrics and register the model, run: +To submit the job, wait for completion, and automatically import metrics and register the model, run: ```bash -qc-cli train wait +qc-cli train start --wait ``` The default polling interval is 30 seconds. It can be changed with `--poll-interval `. diff --git a/src/commands/train.py b/src/commands/train.py index 5958514..3c38927 100644 --- a/src/commands/train.py +++ b/src/commands/train.py @@ -102,8 +102,54 @@ def _finalize_terminal_job( ) +def _wait_for_training_job( + *, + job_name: str, + poll_interval: int, + config_path: str, + cfg: Config, +) -> None: + st = state_ops.store(config_path) + previous_status: str | None = None + try: + while True: + training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) + if training_status.status != previous_status: + color = _STATUS_COLOR.get(training_status.status, "white") + CONSOLE.print( + f"Job [cyan]{training_status.name}[/cyan]: " + f"[{color}]{training_status.status}[/{color}]" + ) + previous_status = training_status.status + if training_status.status in _TERMINAL_STATUSES: + _print_training_status(training_status) + _finalize_terminal_job( + config_path=config_path, + cfg=cfg, + status=training_status, + command="train start --wait", + ) + job_state = st.get_training_job(job_name) + if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: + CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") + return + time.sleep(poll_interval) + except KeyboardInterrupt: + CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]") + raise typer.Exit(130) + + @app.command() -def start(config: str = CONFIG_OPT) -> None: +def start( + wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"), + poll_interval: int = typer.Option( + DEFAULT_POLL_INTERVAL_SECONDS, + "--poll-interval", + min=1, + help="Seconds between status checks when --wait is used", + ), + config: str = CONFIG_OPT, +) -> None: """Submit a SageMaker training job.""" cfg = load_cfg(config) @@ -156,7 +202,15 @@ def start(config: str = CONFIG_OPT) -> None: if run_id: CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") - CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") + if wait: + _wait_for_training_job( + job_name=job_name, + poll_interval=poll_interval, + config_path=config, + cfg=cfg, + ) + else: + CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") @app.command() @@ -185,57 +239,6 @@ def status( CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") -@app.command() -def wait( - job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"), - poll_interval: int = typer.Option( - DEFAULT_POLL_INTERVAL_SECONDS, - "--poll-interval", - min=1, - help="Seconds between SageMaker status checks", - ), - config: str = CONFIG_OPT, -) -> None: - """Wait for a training job and finalize its MLflow run.""" - cfg = load_cfg(config) - st = state_ops.store(config) - if not job_name: - job_name = st.get_last_training_job() - if not job_name: - CONSOLE.print( - "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]" - ) - raise typer.Exit(1) - - previous_status: str | None = None - try: - while True: - training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) - if training_status.status != previous_status: - color = _STATUS_COLOR.get(training_status.status, "white") - CONSOLE.print( - f"Job [cyan]{training_status.name}[/cyan]: " - f"[{color}]{training_status.status}[/{color}]" - ) - previous_status = training_status.status - if training_status.status in _TERMINAL_STATUSES: - _print_training_status(training_status) - _finalize_terminal_job( - config_path=config, - cfg=cfg, - status=training_status, - command="train wait", - ) - job_state = st.get_training_job(job_name) - if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: - CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") - return - time.sleep(poll_interval) - except KeyboardInterrupt: - CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]") - raise typer.Exit(130) - - @app.command(name="list") def list_jobs( limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"), -- 2.49.1 From 5211d0af14b7ea2636d081fb928535a5b6f8b369 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 12 Jun 2026 12:17:02 -0400 Subject: [PATCH 4/6] another update --- README.md | 29 +++-- examples/meter-detection/README.md | 12 +- src/commands/mlflow.py | 43 +++++++ src/commands/train.py | 60 +++++++--- src/tracking/mlflow.py | 185 +++++++++++++++++++---------- src/tracking/upload.py | 38 ++++++ 6 files changed, 278 insertions(+), 89 deletions(-) create mode 100644 src/tracking/upload.py diff --git a/README.md b/README.md index 1901afc..afb8dd7 100644 --- a/README.md +++ b/README.md @@ -128,9 +128,14 @@ qc-cli init --force Overwrite an existing config file ### `mlflow` ``` -qc-cli mlflow open Open a presigned MLflow UI URL in a browser +qc-cli mlflow open Open a presigned MLflow UI URL +qc-cli mlflow upload-metrics [job-name] Upload completed training metrics ``` +`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run, +imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`. +Use `--force` to upload the metrics again. + ### `infra` ``` @@ -163,7 +168,7 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de ``` qc-cli train start Submit a SageMaker training job -qc-cli train start --wait Submit, wait, and finalize MLflow tracking +qc-cli train start --upload-metrics Submit, wait, and upload metrics qc-cli train status [job-name] Show job status; defaults to the last submitted job qc-cli train list List recent training jobs qc-cli train list --limit 3 Show a custom number of recent jobs @@ -171,7 +176,7 @@ qc-cli train list --limit 3 Show a custom number of recent jobs `train start` uses `s3:////` as the training channel and writes outputs under `s3:////`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container. -`train start --wait` checks SageMaker every 30 seconds by default. Use `--poll-interval ` to choose another positive interval. Stopping the local command does not stop the SageMaker job. +`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval ` to choose another positive interval. Stopping the local command does not stop the SageMaker job. The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. @@ -219,15 +224,19 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas Current behavior: 1. `qc-cli train start` submits a SageMaker training job. -2. `qc-cli train status` or `qc-cli train start --wait` finalizes the MLflow run after the job reaches a terminal state. `--wait` polls every 30 seconds by default. -3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: +2. `qc-cli train status` finalizes the MLflow run and registers completed model artifacts. +3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion. +4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job. +5. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: - `qc_cli.stage=experiment` - `qc_cli.artifact_kind=trained_source` - `qc_cli.source=sagemaker` -4. The MLflow alias `experiment-latest` points at the most recently registered experiment version. -5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. +6. The MLflow alias `experiment-latest` points at the most recently registered experiment version. +7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. -Training scripts can include a `training_metrics.json` file in the SageMaker model directory. During finalization, the CLI logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact: +Training scripts can include a `training_metrics.json` file in the SageMaker model directory. The explicit metrics +upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores +the JSON as a run artifact: ```json { @@ -239,7 +248,9 @@ Training scripts can include a `training_metrics.json` file in the SageMaker mod } ``` -Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. Missing or malformed metrics produce a warning but do not block model registration. +Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and +strictly increasing. A missing or malformed metrics artifact fails the upload command without affecting the trained +model or model registration. Future release aliases such as `v1` or `production` can point at a selected deployable artifact. diff --git a/examples/meter-detection/README.md b/examples/meter-detection/README.md index e6d5c27..17a441a 100644 --- a/examples/meter-detection/README.md +++ b/examples/meter-detection/README.md @@ -156,11 +156,17 @@ qc-cli train status qc-cli-YYYYMMDD-HHMMSS To submit the job, wait for completion, and automatically import metrics and register the model, run: ```bash -qc-cli train start --wait +qc-cli train start --upload-metrics ``` The default polling interval is 30 seconds. It can be changed with `--poll-interval `. +The metrics can be also submitted using: + +```bash +qc-cli mlflow upload-metrics +``` + ## SageMaker Outputs When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`. @@ -176,7 +182,9 @@ training_metrics.json The archive is stored under the configured `s3.model_prefix`. -During MLflow finalization, `training_metrics.json` provides per-epoch training and validation losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall are more meaningful than classification accuracy when assessing model quality. +The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation +losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall +are more meaningful than classification accuracy when assessing model quality. ## 6. Configure Qualcomm AI Hub diff --git a/src/commands/mlflow.py b/src/commands/mlflow.py index 8fd3ef2..401a3d4 100644 --- a/src/commands/mlflow.py +++ b/src/commands/mlflow.py @@ -2,8 +2,11 @@ import webbrowser import typer +from src import state as state_ops from src.aws import mlflow as aws_mlflow from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg +from src.config import MlflowMode +from src.tracking.upload import upload_training_metrics app = typer.Typer(help="Manage MLflow tracking server access") @@ -39,3 +42,43 @@ def open_mlflow(config: str = CONFIG_OPT) -> None: CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.") else: CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]") + + +@app.command(name="upload-metrics") +def upload_metrics( + job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"), + force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"), + config: str = CONFIG_OPT, +) -> None: + """Upload a completed training job's metric history to MLflow.""" + cfg = load_cfg(config) + if cfg.mlflow.mode is MlflowMode.disabled: + CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]") + raise typer.Exit(1) + + st = state_ops.store(config) + if not job_name: + job_name = st.get_last_training_job() + if not job_name: + CONSOLE.print( + "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]" + ) + raise typer.Exit(1) + + if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force: + CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].") + return + + try: + run_id = upload_training_metrics( + job_name=job_name, + config_path=config, + cfg=cfg, + force=force, + ) + except Exception as e: + CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") + raise typer.Exit(1) + + CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].") + CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") diff --git a/src/commands/train.py b/src/commands/train.py index 3c38927..a6c3c1a 100644 --- a/src/commands/train.py +++ b/src/commands/train.py @@ -12,6 +12,7 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg from src.config import Config, MlflowMode from src.infra.state import read_infra_state from src.tracking.mlflow import MlflowTracker +from src.tracking.upload import upload_training_metrics app = typer.Typer(help="Manage SageMaker training jobs") @@ -102,7 +103,7 @@ def _finalize_terminal_job( ) -def _wait_for_training_job( +def _wait_and_upload_metrics( *, job_name: str, poll_interval: int, @@ -123,12 +124,21 @@ def _wait_for_training_job( previous_status = training_status.status if training_status.status in _TERMINAL_STATUSES: _print_training_status(training_status) - _finalize_terminal_job( - config_path=config_path, - cfg=cfg, - status=training_status, - command="train start --wait", - ) + if training_status.status != "Completed": + raise typer.Exit(1) + try: + run_id = upload_training_metrics( + job_name=job_name, + config_path=config_path, + cfg=cfg, + ) + CONSOLE.print(f"[green]✓[/green] Uploaded training metrics to MLflow run [cyan]{run_id}[/cyan].") + except Exception as e: + CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") + CONSOLE.print( + f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]." + ) + raise typer.Exit(1) job_state = st.get_training_job(job_name) if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") @@ -141,18 +151,26 @@ def _wait_for_training_job( @app.command() def start( - wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"), + upload_metrics: bool = typer.Option( + False, + "--upload-metrics", + help="Wait for completion, then upload training metrics to MLflow", + ), poll_interval: int = typer.Option( DEFAULT_POLL_INTERVAL_SECONDS, "--poll-interval", min=1, - help="Seconds between status checks when --wait is used", + help="Seconds between status checks when --upload-metrics is used", ), config: str = CONFIG_OPT, ) -> None: """Submit a SageMaker training job.""" cfg = load_cfg(config) + if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled: + CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]") + raise typer.Exit(1) + if not cfg.sagemaker.training.image_uri: CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]") CONSOLE.print( @@ -189,12 +207,20 @@ def start( st = state_ops.store(config) st.set_last_training_job(job_name) - run_id = tracker.start_training_run( - training_job, - region=cfg.aws.region, - profile=cfg.aws.profile, - role_arn=role_arn, - ) + try: + run_id = tracker.start_training_run( + training_job, + region=cfg.aws.region, + profile=cfg.aws.profile, + role_arn=role_arn, + ) + except Exception as e: + run_id = None + CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]") + CONSOLE.print( + "The SageMaker job is still running. Upload metrics after completion with " + f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]." + ) if run_id: st.update_training_job(job_name, mlflow_run_id=run_id) @@ -202,8 +228,8 @@ def start( if run_id: CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") - if wait: - _wait_for_training_job( + if upload_metrics: + _wait_and_upload_metrics( job_name=job_name, poll_interval=poll_interval, config_path=config, diff --git a/src/tracking/mlflow.py b/src/tracking/mlflow.py index e125ae8..ad46e37 100644 --- a/src/tracking/mlflow.py +++ b/src/tracking/mlflow.py @@ -31,6 +31,17 @@ class Tracker(Protocol): command: str, ) -> FinalizeResult: ... + def ensure_training_run(self, job_name: str) -> str: ... + + def upload_training_metrics( + self, + *, + run_id: str, + training_job_status: Any, + region: str, + profile: str, + ) -> None: ... + @dataclass(frozen=True) class NoopTracker: @@ -48,6 +59,19 @@ class NoopTracker: ) -> FinalizeResult: return FinalizeResult() + def ensure_training_run(self, job_name: str) -> str: + raise RuntimeError("MLflow is disabled.") + + def upload_training_metrics( + self, + *, + run_id: str, + training_job_status: Any, + region: str, + profile: str, + ) -> None: + raise RuntimeError("MLflow is disabled.") + @dataclass(frozen=True) class MlflowTracker: @@ -73,7 +97,6 @@ class MlflowTracker: tracking_server_name, ) mlflow.set_tracking_uri(tracking_uri) - mlflow.set_experiment(cfg.mlflow.experiment_name) return cls( tracking_uri=tracking_uri, @@ -83,34 +106,33 @@ class MlflowTracker: ) def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: - run = mlflow.start_run(run_name=training_job.job_name) - run_id = str(run.info.run_id) - - params = { - "aws.region": region, - "aws.profile": profile, - "sagemaker.role_arn": role_arn, - "sagemaker.job_name": training_job.job_name, - "sagemaker.training_image": training_job.image_uri, - "sagemaker.instance_type": training_job.instance_type, - "sagemaker.instance_count": training_job.instance_count, - "sagemaker.s3_train_uri": training_job.s3_train_uri, - "sagemaker.s3_output_path": training_job.s3_output_path, - "sagemaker.entry_point": training_job.entry_point, - "sagemaker.source_dir": training_job.source_dir, - } - self._log_params(params) - self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) - mlflow.set_tags( - { - "qc_cli.stage": "experiment", - "qc_cli.artifact_kind": "trained_source", - "qc_cli.source": "sagemaker", - "qc_cli.command": "train start", + mlflow.set_experiment(self.experiment_name) + with mlflow.start_run(run_name=training_job.job_name) as run: + run_id = str(run.info.run_id) + params = { + "aws.region": region, + "aws.profile": profile, + "sagemaker.role_arn": role_arn, "sagemaker.job_name": training_job.job_name, + "sagemaker.training_image": training_job.image_uri, + "sagemaker.instance_type": training_job.instance_type, + "sagemaker.instance_count": training_job.instance_count, + "sagemaker.s3_train_uri": training_job.s3_train_uri, + "sagemaker.s3_output_path": training_job.s3_output_path, + "sagemaker.entry_point": training_job.entry_point, + "sagemaker.source_dir": training_job.source_dir, } - ) - mlflow.end_run() + self._log_params(params) + self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) + mlflow.set_tags( + { + "qc_cli.stage": "experiment", + "qc_cli.artifact_kind": "trained_source", + "qc_cli.source": "sagemaker", + "qc_cli.command": "train start", + "sagemaker.job_name": training_job.job_name, + } + ) return run_id def finalize_training_run( @@ -125,7 +147,6 @@ class MlflowTracker: if not run_id: return FinalizeResult() - warnings: list[str] = [] with mlflow.start_run(run_id=run_id): self._log_params( { @@ -137,22 +158,14 @@ class MlflowTracker: } ) self._log_final_metrics(training_job_status.raw) - if training_job_status.status == "Completed" and training_job_status.model_artifacts: - warnings.extend( - self._log_training_metrics( - training_job_status.model_artifacts, - region=region, - profile=profile, - ) - ) mlflow.set_tag("qc_cli.command", command) if training_job_status.status != "Completed" or not training_job_status.model_artifacts: mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) - return FinalizeResult(warnings=tuple(warnings)) + return FinalizeResult() if not self.register_trained_models: - return FinalizeResult(warnings=tuple(warnings)) + return FinalizeResult() client = MlflowClient() self._ensure_registered_model(client, self.registered_model_name) @@ -171,7 +184,61 @@ class MlflowTracker: client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) mlflow.set_tag("qc_cli.registered_model_version", version_number) - return FinalizeResult(registered_model_version=version_number, warnings=tuple(warnings)) + return FinalizeResult(registered_model_version=version_number) + + def ensure_training_run(self, job_name: str) -> str: + client = MlflowClient() + experiment = client.get_experiment_by_name(self.experiment_name) + if experiment is None: + experiment_id = mlflow.create_experiment(self.experiment_name) + else: + experiment_id = experiment.experiment_id + + for run in client.search_runs([experiment_id], max_results=1000): + if run.data.tags.get("sagemaker.job_name") == job_name: + return str(run.info.run_id) + + run = client.create_run( + experiment_id, + run_name=job_name, + tags={ + "qc_cli.stage": "experiment", + "qc_cli.artifact_kind": "trained_source", + "qc_cli.source": "sagemaker", + "qc_cli.command": "mlflow upload-metrics", + "sagemaker.job_name": job_name, + }, + ) + return str(run.info.run_id) + + def upload_training_metrics( + self, + *, + run_id: str, + training_job_status: Any, + region: str, + profile: str, + ) -> None: + if not training_job_status.model_artifacts: + raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.") + + with mlflow.start_run(run_id=run_id): + self._log_params( + { + "sagemaker.training_status": training_job_status.status, + "sagemaker.created_at": training_job_status.created, + "sagemaker.modified_at": training_job_status.modified, + "sagemaker.model_artifacts": training_job_status.model_artifacts, + "sagemaker.failure_reason": training_job_status.failure_reason, + } + ) + self._log_final_metrics(training_job_status.raw) + self._log_training_metrics( + training_job_status.model_artifacts, + region=region, + profile=profile, + ) + mlflow.set_tag("qc_cli.command", "mlflow upload-metrics") def _log_params(self, params: dict[str, Any]) -> None: cleaned = {key: str(value) for key, value in params.items() if value is not None} @@ -188,28 +255,24 @@ class MlflowTracker: if metrics: mlflow.log_metrics(metrics) - def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> list[str]: - try: - with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir: - archive_path = s3.download_file( - region, - profile, - model_artifacts, - os.path.join(temp_dir, "model.tar.gz"), - ) - metrics_data = read_training_metrics_from_tar(archive_path) - if metrics_data is None: - return [f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact."] - metrics = parse_training_metrics(metrics_data) - for metric_step in metrics.steps: - if metric_step.metrics: - mlflow.log_metrics(metric_step.metrics, step=metric_step.step) - if metrics.summary: - mlflow.log_metrics(metrics.summary) - mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME) - except Exception as exc: - return [f"Could not import training metrics: {exc}"] - return [] + def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> None: + with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir: + archive_path = s3.download_file( + region, + profile, + model_artifacts, + os.path.join(temp_dir, "model.tar.gz"), + ) + metrics_data = read_training_metrics_from_tar(archive_path) + if metrics_data is None: + raise ValueError(f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact.") + metrics = parse_training_metrics(metrics_data) + for metric_step in metrics.steps: + if metric_step.metrics: + mlflow.log_metrics(metric_step.metrics, step=metric_step.step) + if metrics.summary: + mlflow.log_metrics(metrics.summary) + mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME) def _ensure_registered_model(self, client: MlflowClient, name: str) -> None: try: diff --git a/src/tracking/upload.py b/src/tracking/upload.py new file mode 100644 index 0000000..add9d98 --- /dev/null +++ b/src/tracking/upload.py @@ -0,0 +1,38 @@ +from src import state as state_ops +from src.aws import sagemaker as sm_ops +from src.config import Config, MlflowMode +from src.tracking.mlflow import MlflowTracker + + +def upload_training_metrics( + *, + job_name: str, + config_path: str, + cfg: Config, + force: bool = False, +) -> str: + if cfg.mlflow.mode is MlflowMode.disabled: + raise RuntimeError("MLflow is disabled in config.yaml.") + + st = state_ops.store(config_path) + job_state = st.get_training_job(job_name) + if job_state.get("mlflow_metrics_uploaded") and not force: + return str(job_state.get("mlflow_run_id") or "") + + status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) + if status.status != "Completed": + raise RuntimeError( + f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion." + ) + + tracker = MlflowTracker.from_config(cfg) + run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name)) + st.update_training_job(job_name, mlflow_run_id=run_id) + tracker.upload_training_metrics( + run_id=run_id, + training_job_status=status, + region=cfg.aws.region, + profile=cfg.aws.profile, + ) + st.update_training_job(job_name, mlflow_metrics_uploaded=True) + return run_id -- 2.49.1 From 4c33a016f0b7243e85df5d7c426c9d5ac8823581 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 12 Jun 2026 12:21:41 -0400 Subject: [PATCH 5/6] simplify --- README.md | 10 +++++--- src/commands/mlflow.py | 9 +++++-- src/commands/train.py | 56 ++++++++---------------------------------- src/tracking/upload.py | 41 ++++++++++++++++++++++++++++--- 4 files changed, 61 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index afb8dd7..d750381 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,11 @@ mlflow: tracking_server_name: your-tracking-server-name ``` -When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release. +When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Metric upload through +`train start --upload-metrics` or `mlflow upload-metrics` finalizes that run and registers completed model artifacts +as experiment model versions using the `experiment-latest` MLflow alias. `train status` reads SageMaker status only. +An experiment version is an immutable trained-source artifact; it records that training produced a model, not that +the model is better than earlier versions or ready for release. To open the managed SageMaker MLflow UI, request a fresh presigned URL: @@ -224,10 +228,10 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas Current behavior: 1. `qc-cli train start` submits a SageMaker training job. -2. `qc-cli train status` finalizes the MLflow run and registers completed model artifacts. +2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow. 3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion. 4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job. -5. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: +5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with: - `qc_cli.stage=experiment` - `qc_cli.artifact_kind=trained_source` - `qc_cli.source=sagemaker` diff --git a/src/commands/mlflow.py b/src/commands/mlflow.py index 401a3d4..f282e2f 100644 --- a/src/commands/mlflow.py +++ b/src/commands/mlflow.py @@ -70,7 +70,7 @@ def upload_metrics( return try: - run_id = upload_training_metrics( + result = upload_training_metrics( job_name=job_name, config_path=config, cfg=cfg, @@ -81,4 +81,9 @@ def upload_metrics( raise typer.Exit(1) CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].") - CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") + CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]") + if result.registered_model_version: + CONSOLE.print( + f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " + "([cyan]experiment-latest[/cyan])" + ) diff --git a/src/commands/train.py b/src/commands/train.py index a6c3c1a..4fe0b1c 100644 --- a/src/commands/train.py +++ b/src/commands/train.py @@ -64,45 +64,6 @@ def _print_training_status(status: sm_ops.TrainingJobStatus) -> None: CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]") -def _finalize_terminal_job( - *, - config_path: str, - cfg: Config, - status: sm_ops.TrainingJobStatus, - command: str, -) -> None: - if status.status not in _TERMINAL_STATUSES: - return - - st = state_ops.store(config_path) - job_state = st.get_training_job(status.name) - run_id = job_state.get("mlflow_run_id") - if not run_id or job_state.get("mlflow_finalized_status"): - return - - tracker = _tracker(cfg) - result = tracker.finalize_training_run( - run_id=str(run_id), - training_job_status=status, - region=cfg.aws.region, - profile=cfg.aws.profile, - command=command, - ) - updates = {"mlflow_finalized_status": status.status} - if result.registered_model_version: - updates["registered_model_version"] = result.registered_model_version - st.update_training_job(status.name, **updates) - - for warning in result.warnings: - CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]") - if result.registered_model_version: - st.set_latest_experiment_model_version(result.registered_model_version) - CONSOLE.print( - f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " - "([cyan]experiment-latest[/cyan])" - ) - - def _wait_and_upload_metrics( *, job_name: str, @@ -127,12 +88,20 @@ def _wait_and_upload_metrics( if training_status.status != "Completed": raise typer.Exit(1) try: - run_id = upload_training_metrics( + result = upload_training_metrics( job_name=job_name, config_path=config_path, cfg=cfg, ) - CONSOLE.print(f"[green]✓[/green] Uploaded training metrics to MLflow run [cyan]{run_id}[/cyan].") + CONSOLE.print( + f"[green]✓[/green] Uploaded training metrics to MLflow run " + f"[cyan]{result.run_id}[/cyan]." + ) + if result.registered_model_version: + CONSOLE.print( + f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " + "([cyan]experiment-latest[/cyan])" + ) except Exception as e: CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") CONSOLE.print( @@ -258,11 +227,6 @@ def status( status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) _print_training_status(status) - _finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status") - - job_state = st.get_training_job(job_name) - if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: - CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") @app.command(name="list") diff --git a/src/tracking/upload.py b/src/tracking/upload.py index add9d98..78e6a41 100644 --- a/src/tracking/upload.py +++ b/src/tracking/upload.py @@ -1,23 +1,38 @@ +from dataclasses import dataclass + from src import state as state_ops from src.aws import sagemaker as sm_ops from src.config import Config, MlflowMode from src.tracking.mlflow import MlflowTracker +@dataclass(frozen=True) +class MetricsUploadResult: + run_id: str + registered_model_version: str | None = None + + def upload_training_metrics( *, job_name: str, config_path: str, cfg: Config, force: bool = False, -) -> str: +) -> MetricsUploadResult: if cfg.mlflow.mode is MlflowMode.disabled: raise RuntimeError("MLflow is disabled in config.yaml.") st = state_ops.store(config_path) job_state = st.get_training_job(job_name) if job_state.get("mlflow_metrics_uploaded") and not force: - return str(job_state.get("mlflow_run_id") or "") + return MetricsUploadResult( + run_id=str(job_state.get("mlflow_run_id") or ""), + registered_model_version=( + str(job_state["registered_model_version"]) + if job_state.get("registered_model_version") + else None + ), + ) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) if status.status != "Completed": @@ -34,5 +49,23 @@ def upload_training_metrics( region=cfg.aws.region, profile=cfg.aws.profile, ) - st.update_training_job(job_name, mlflow_metrics_uploaded=True) - return run_id + finalized = tracker.finalize_training_run( + run_id=run_id, + training_job_status=status, + region=cfg.aws.region, + profile=cfg.aws.profile, + command="mlflow upload-metrics", + ) + updates = { + "mlflow_metrics_uploaded": True, + "mlflow_finalized_status": status.status, + } + if finalized.registered_model_version: + updates["registered_model_version"] = finalized.registered_model_version + st.update_training_job(job_name, **updates) + if finalized.registered_model_version: + st.set_latest_experiment_model_version(finalized.registered_model_version) + return MetricsUploadResult( + run_id=run_id, + registered_model_version=finalized.registered_model_version, + ) -- 2.49.1 From 20cd3f979491e02e6003424e3ede4cb0f23803f6 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 12 Jun 2026 14:10:52 -0400 Subject: [PATCH 6/6] update --- README.md | 11 ++++++----- src/commands/mlflow.py | 8 +++++++- src/commands/train.py | 14 ++++++++++---- src/tracking/mlflow.py | 15 +++++++++------ src/tracking/upload.py | 6 +++++- 5 files changed, 37 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index d750381..cbb31bb 100644 --- a/README.md +++ b/README.md @@ -238,9 +238,9 @@ Current behavior: 6. The MLflow alias `experiment-latest` points at the most recently registered experiment version. 7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. -Training scripts can include a `training_metrics.json` file in the SageMaker model directory. The explicit metrics -upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores -the JSON as a run artifact: +Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the +explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow +step and stores the JSON as a run artifact: ```json { @@ -253,8 +253,9 @@ the JSON as a run artifact: ``` Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and -strictly increasing. A missing or malformed metrics artifact fails the upload command without affecting the trained -model or model registration. +strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues +model registration without per-epoch history. A malformed metrics artifact still fails the upload command without +affecting the trained model or model registration. Future release aliases such as `v1` or `production` can point at a selected deployable artifact. diff --git a/src/commands/mlflow.py b/src/commands/mlflow.py index f282e2f..14d2f57 100644 --- a/src/commands/mlflow.py +++ b/src/commands/mlflow.py @@ -80,7 +80,13 @@ def upload_metrics( CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") raise typer.Exit(1) - CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].") + if result.metrics_history_uploaded: + CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].") + else: + CONSOLE.print( + f"[yellow]No training_metrics.json was found in the SageMaker model artifact for " + f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]" + ) CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]") if result.registered_model_version: CONSOLE.print( diff --git a/src/commands/train.py b/src/commands/train.py index 4fe0b1c..31356d3 100644 --- a/src/commands/train.py +++ b/src/commands/train.py @@ -93,10 +93,16 @@ def _wait_and_upload_metrics( config_path=config_path, cfg=cfg, ) - CONSOLE.print( - f"[green]✓[/green] Uploaded training metrics to MLflow run " - f"[cyan]{result.run_id}[/cyan]." - ) + if result.metrics_history_uploaded: + CONSOLE.print( + f"[green]✓[/green] Uploaded training metrics to MLflow run " + f"[cyan]{result.run_id}[/cyan]." + ) + else: + CONSOLE.print( + "[yellow]No training_metrics.json was found in the SageMaker model artifact. " + "Uploaded SageMaker final metrics only.[/yellow]" + ) if result.registered_model_version: CONSOLE.print( f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " diff --git a/src/tracking/mlflow.py b/src/tracking/mlflow.py index 6583cf9..287e259 100644 --- a/src/tracking/mlflow.py +++ b/src/tracking/mlflow.py @@ -40,7 +40,7 @@ class Tracker(Protocol): training_job_status: Any, region: str, profile: str, - ) -> None: ... + ) -> bool: ... @dataclass(frozen=True) @@ -69,7 +69,7 @@ class NoopTracker: training_job_status: Any, region: str, profile: str, - ) -> None: + ) -> bool: raise RuntimeError("MLflow is disabled.") @@ -208,7 +208,7 @@ class MlflowTracker: training_job_status: Any, region: str, profile: str, - ) -> None: + ) -> bool: if not training_job_status.model_artifacts: raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.") @@ -216,12 +216,14 @@ class MlflowTracker: with mlflow.start_run(run_id=run_id): self._log_params(self.tracking_backend.training_status_params(training_job_status)) self._log_final_metrics(training_job_status.raw) - self._log_training_metrics( + history_uploaded = self._log_training_metrics( training_job_status.model_artifacts, region=region, profile=profile, ) mlflow.set_tag("qc_cli.command", "mlflow upload-metrics") + mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower()) + return history_uploaded def _log_params(self, params: dict[str, Any]) -> None: cleaned = {key: str(value) for key, value in params.items() if value is not None} @@ -238,7 +240,7 @@ class MlflowTracker: if metrics: mlflow.log_metrics(metrics) - def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> None: + def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool: with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir: archive_path = s3.download_file( region, @@ -248,7 +250,7 @@ class MlflowTracker: ) metrics_data = read_training_metrics_from_tar(archive_path) if metrics_data is None: - raise ValueError(f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact.") + return False metrics = parse_training_metrics(metrics_data) for metric_step in metrics.steps: if metric_step.metrics: @@ -256,6 +258,7 @@ class MlflowTracker: if metrics.summary: mlflow.log_metrics(metrics.summary) mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME) + return True def _ensure_registered_model(self, client: MlflowClient, name: str) -> None: try: diff --git a/src/tracking/upload.py b/src/tracking/upload.py index 78e6a41..5cf77a3 100644 --- a/src/tracking/upload.py +++ b/src/tracking/upload.py @@ -10,6 +10,7 @@ from src.tracking.mlflow import MlflowTracker class MetricsUploadResult: run_id: str registered_model_version: str | None = None + metrics_history_uploaded: bool = True def upload_training_metrics( @@ -32,6 +33,7 @@ def upload_training_metrics( if job_state.get("registered_model_version") else None ), + metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)), ) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) @@ -43,7 +45,7 @@ def upload_training_metrics( tracker = MlflowTracker.from_config(cfg) run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name)) st.update_training_job(job_name, mlflow_run_id=run_id) - tracker.upload_training_metrics( + metrics_history_uploaded = tracker.upload_training_metrics( run_id=run_id, training_job_status=status, region=cfg.aws.region, @@ -58,6 +60,7 @@ def upload_training_metrics( ) updates = { "mlflow_metrics_uploaded": True, + "mlflow_metrics_history_uploaded": metrics_history_uploaded, "mlflow_finalized_status": status.status, } if finalized.registered_model_version: @@ -68,4 +71,5 @@ def upload_training_metrics( return MetricsUploadResult( run_id=run_id, registered_model_version=finalized.registered_model_version, + metrics_history_uploaded=metrics_history_uploaded, ) -- 2.49.1