diff --git a/README.md b/README.md index 3a61e88..cbb31bb 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,11 @@ mlflow: tracking_server_name: your-tracking-server-name ``` -When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release. +When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Metric upload through +`train start --upload-metrics` or `mlflow upload-metrics` finalizes that run and registers completed model artifacts +as experiment model versions using the `experiment-latest` MLflow alias. `train status` reads SageMaker status only. +An experiment version is an immutable trained-source artifact; it records that training produced a model, not that +the model is better than earlier versions or ready for release. To open the managed SageMaker MLflow UI, request a fresh presigned URL: @@ -128,9 +132,14 @@ qc-cli init --force Overwrite an existing config file ### `mlflow` ``` -qc-cli mlflow open Open a presigned MLflow UI URL in a browser +qc-cli mlflow open Open a presigned MLflow UI URL +qc-cli mlflow upload-metrics [job-name] Upload completed training metrics ``` +`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run, +imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`. +Use `--force` to upload the metrics again. + ### `infra` ``` @@ -163,6 +172,7 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. 
File uploads de ``` qc-cli train start Submit a SageMaker training job +qc-cli train start --upload-metrics Submit, wait, and upload metrics qc-cli train status [job-name] Show job status; defaults to the last submitted job qc-cli train list List recent training jobs qc-cli train list --limit 3 Show a custom number of recent jobs @@ -170,6 +180,8 @@ qc-cli train list --limit 3 Show a custom number of recent jobs `train start` uses `s3:////` as the training channel and writes outputs under `s3:////`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container. +`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval ` to choose another positive interval. Stopping the local command does not stop the SageMaker job. + The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. ### `ai-hub` @@ -216,13 +228,34 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas Current behavior: 1. `qc-cli train start` submits a SageMaker training job. -2. `qc-cli train status` finalizes the MLflow run after the job reaches a terminal state. -3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: +2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow. +3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion. +4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job. +5. 
The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with: - `qc_cli.stage=experiment` - `qc_cli.artifact_kind=trained_source` - `qc_cli.source=sagemaker` -4. The MLflow alias `experiment-latest` points at the most recently registered experiment version. -5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. +6. The MLflow alias `experiment-latest` points at the most recently registered experiment version. +7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. + +Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the +explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow +step and stores the JSON as a run artifact: + +```json +{ + "schema_version": 1, + "steps": [ + {"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}} + ], + "summary": {"summary.best_epoch": 0} +} +``` + +Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and +strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues +model registration without per-epoch history. A malformed metrics artifact still fails the upload command without +affecting the trained model or model registration. Future release aliases such as `v1` or `production` can point at a selected deployable artifact. 
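For training loops that do not come from Ultralytics, the README rules above (non-empty metric names, finite numeric values, non-negative and strictly increasing steps) can be met by construction when the file is written. The following is a minimal sketch only, not part of this change: `build_training_metrics` and `write_training_metrics_file` are hypothetical helper names, and the single `summary.final_epoch` key is an assumption borrowed from the example's naming style.

```python
import json
import math
from pathlib import Path


def build_training_metrics(history: list[dict[str, float]]) -> dict:
    """Build a schema_version 1 payload from per-epoch metric dicts.

    history[i] holds the metrics for epoch i, so steps are 0, 1, 2, ...
    and are therefore non-negative, unique, and strictly increasing.
    (Hypothetical helper; not part of the CLI.)
    """
    steps = []
    for epoch, metrics in enumerate(history):
        for name, value in metrics.items():
            if not isinstance(name, str) or not name:
                raise ValueError(f"epoch {epoch}: metric names must be non-empty strings")
            # bool is a subclass of int, so reject it explicitly.
            if isinstance(value, bool) or not isinstance(value, (int, float)) or not math.isfinite(value):
                raise ValueError(f"epoch {epoch}: metric {name!r} must be a finite number")
        steps.append({"step": epoch, "metrics": {k: float(v) for k, v in metrics.items()}})
    # Assumed summary content: only the index of the last recorded epoch.
    summary = {"summary.final_epoch": float(len(steps) - 1)} if steps else {}
    return {"schema_version": 1, "steps": steps, "summary": summary}


def write_training_metrics_file(history: list[dict[str, float]], model_dir: Path) -> Path:
    """Write training_metrics.json into the model directory (e.g. /opt/ml/model)."""
    destination = model_dir / "training_metrics.json"
    destination.write_text(json.dumps(build_training_metrics(history), indent=2), encoding="utf-8")
    return destination
```

Validating at write time means a bad value fails the training script early, rather than surfacing later as a malformed artifact when the upload command runs.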
diff --git a/examples/meter-detection/README.md b/examples/meter-detection/README.md index 35955df..17a441a 100644 --- a/examples/meter-detection/README.md +++ b/examples/meter-detection/README.md @@ -153,6 +153,20 @@ Or pass the job name explicitly: qc-cli train status qc-cli-YYYYMMDD-HHMMSS ``` +To submit the job, wait for completion, and automatically import metrics and register the model, run: + +```bash +qc-cli train start --upload-metrics +``` + +The default polling interval is 30 seconds. It can be changed with `--poll-interval `. + +The metrics can also be submitted using: + +```bash +qc-cli mlflow upload-metrics +``` + ## SageMaker Outputs When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`. @@ -163,10 +177,15 @@ This example writes: best.pt model.onnx metrics.json +training_metrics.json ``` The archive is stored under the configured `s3.model_prefix`. +The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation +losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall +are more meaningful than classification accuracy when assessing model quality. + ## 6. 
Configure Qualcomm AI Hub Authenticate with Qualcomm AI Hub: diff --git a/examples/meter-detection/source/train.py b/examples/meter-detection/source/train.py index 99aef51..f40b872 100644 --- a/examples/meter-detection/source/train.py +++ b/examples/meter-detection/source/train.py @@ -12,6 +12,7 @@ from typing import Any import yaml from sanitize_onnx import sanitize_onnx +from training_metrics import write_training_metrics from ultralytics import YOLO # type: ignore[reportMissingImports] @@ -101,6 +102,7 @@ def main() -> None: if not trained_weights.exists(): raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}") + write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json") copy_if_exists(trained_weights, model_dir / "best.pt") trained_model = YOLO(str(trained_weights)) onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz)) diff --git a/examples/meter-detection/source/training_metrics.py b/examples/meter-detection/source/training_metrics.py new file mode 100644 index 0000000..004c557 --- /dev/null +++ b/examples/meter-detection/source/training_metrics.py @@ -0,0 +1,82 @@ +import csv +import json +import math +import re +from pathlib import Path +from typing import Any + +METRIC_NAMES = { + "metrics/precision(B)": "val.precision", + "metrics/recall(B)": "val.recall", + "metrics/mAP50(B)": "val.map50", + "metrics/mAP50-95(B)": "val.map50_95", + "train/box_loss": "train.box_loss", + "train/cls_loss": "train.cls_loss", + "train/dfl_loss": "train.dfl_loss", + "val/box_loss": "val.box_loss", + "val/cls_loss": "val.cls_loss", + "val/dfl_loss": "val.dfl_loss", + "time": "train.elapsed_seconds", +} + + +def write_training_metrics(results_csv: Path, destination: Path) -> None: + steps = _read_metric_steps(results_csv) + summary = _build_summary(steps) + payload = { + "schema_version": 1, + "steps": steps, + "summary": summary, + } + destination.write_text(json.dumps(payload, indent=2), 
encoding="utf-8") + print(f"Saved {destination}") + + +def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]: + if not results_csv.is_file(): + raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}") + + steps: list[dict[str, Any]] = [] + with results_csv.open(newline="", encoding="utf-8") as csv_file: + for row_index, raw_row in enumerate(csv.DictReader(csv_file)): + row = {str(key).strip(): value for key, value in raw_row.items()} + raw_epoch = row.pop("epoch", row_index) + step = int(float(raw_epoch)) + metrics: dict[str, float] = {} + for source_name, raw_value in row.items(): + if raw_value is None or not raw_value.strip(): + continue + try: + value = float(raw_value) + except ValueError: + continue + if math.isfinite(value): + metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value + steps.append({"step": step, "metrics": metrics}) + return steps + + +def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]: + if not steps: + return {} + + summary: dict[str, float] = {} + final_step = steps[-1] + summary["summary.final_epoch"] = float(final_step["step"]) + for name, value in final_step["metrics"].items(): + summary[f"summary.final.{name}"] = value + + scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]] + if scored_steps: + best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"]) + summary["summary.best_epoch"] = float(best_step["step"]) + summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"] + if "val.map50" in best_step["metrics"]: + summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"] + return summary + + +def _normalize_metric_name(name: str) -> str: + normalized = name.replace("/", ".") + normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized) + return normalized.strip("._") or "unnamed" diff --git a/src/aws/mlflow.py b/src/aws/mlflow.py index 35433bb..344e70f 100644 --- 
a/src/aws/mlflow.py +++ b/src/aws/mlflow.py @@ -1,3 +1,6 @@ +import os +from collections.abc import Generator +from contextlib import contextmanager from typing import Any, cast import boto3 @@ -34,3 +37,38 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) - client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker") response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name) return str(response["AuthorizedUrl"]) + + +@contextmanager +def tracking_auth_env(profile: str, region: str) -> Generator[None]: + credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials() + if credentials is None: + raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.") + + frozen_credentials = credentials.get_frozen_credentials() + if not frozen_credentials.access_key or not frozen_credentials.secret_key: + raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.") + + env_updates = { + "AWS_PROFILE": profile, + "AWS_DEFAULT_REGION": region, + "AWS_REGION": region, + "AWS_ACCESS_KEY_ID": frozen_credentials.access_key, + "AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key, + } + if frozen_credentials.token: + env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token + + restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"} + previous_env = {key: os.environ.get(key) for key in restore_keys} + try: + os.environ.update(env_updates) + if not frozen_credentials.token: + os.environ.pop("AWS_SESSION_TOKEN", None) + yield + finally: + for key, value in previous_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value diff --git a/src/cloud/__init__.py b/src/cloud/__init__.py index e69de29..75bf900 100644 --- a/src/cloud/__init__.py +++ b/src/cloud/__init__.py @@ -0,0 +1 @@ +"""Cloud provider adapters.""" diff --git a/src/cloud/mlflow.py b/src/cloud/mlflow.py new file mode 100644 index 
0000000..08ae203 --- /dev/null +++ b/src/cloud/mlflow.py @@ -0,0 +1,77 @@ +from contextlib import AbstractContextManager +from dataclasses import dataclass +from typing import Any, Protocol + +from src.aws import mlflow as aws_mlflow +from src.config import Config + + +class MlflowTrackingBackend(Protocol): + @property + def provider_name(self) -> str: ... + + @property + def profile(self) -> str: ... + + @property + def region(self) -> str: ... + + def get_tracking_uri(self, tracking_server_name: str) -> str: ... + + def auth_env(self) -> AbstractContextManager[None]: ... + + def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ... + + def training_run_tags(self, training_job: Any) -> dict[str, Any]: ... + + def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ... + + def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ... + + +@dataclass(frozen=True) +class AwsMlflowTrackingBackend: + profile: str + region: str + provider_name: str = "aws" + + def get_tracking_uri(self, tracking_server_name: str) -> str: + return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name) + + def auth_env(self) -> AbstractContextManager[None]: + return aws_mlflow.tracking_auth_env(self.profile, self.region) + + def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: + return { + "provider.name": self.provider_name, + "provider.region": region, + "provider.profile": profile, + "sagemaker.role_arn": role_arn, + "sagemaker.job_name": training_job.job_name, + "sagemaker.training_image": training_job.image_uri, + "sagemaker.instance_type": training_job.instance_type, + "sagemaker.instance_count": training_job.instance_count, + "sagemaker.s3_train_uri": training_job.s3_train_uri, + "sagemaker.s3_output_path": training_job.s3_output_path, + "sagemaker.entry_point": training_job.entry_point, + 
"sagemaker.source_dir": training_job.source_dir, + } + + def training_run_tags(self, training_job: Any) -> dict[str, Any]: + return {"sagemaker.job_name": training_job.job_name} + + def training_status_params(self, training_job_status: Any) -> dict[str, Any]: + return { + "sagemaker.training_status": training_job_status.status, + "sagemaker.created_at": training_job_status.created, + "sagemaker.modified_at": training_job_status.modified, + "sagemaker.model_artifacts": training_job_status.model_artifacts, + "sagemaker.failure_reason": training_job_status.failure_reason, + } + + def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: + return {"sagemaker.job_name": training_job_status.name} + + +def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend: + return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region) diff --git a/src/commands/mlflow.py b/src/commands/mlflow.py index 8fd3ef2..14d2f57 100644 --- a/src/commands/mlflow.py +++ b/src/commands/mlflow.py @@ -2,8 +2,11 @@ import webbrowser import typer +from src import state as state_ops from src.aws import mlflow as aws_mlflow from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg +from src.config import MlflowMode +from src.tracking.upload import upload_training_metrics app = typer.Typer(help="Manage MLflow tracking server access") @@ -39,3 +42,54 @@ def open_mlflow(config: str = CONFIG_OPT) -> None: CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.") else: CONSOLE.print("[yellow]Could not open a browser automatically. 
Open the URL above manually.[/yellow]") + + +@app.command(name="upload-metrics") +def upload_metrics( + job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"), + force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"), + config: str = CONFIG_OPT, +) -> None: + """Upload a completed training job's metric history to MLflow.""" + cfg = load_cfg(config) + if cfg.mlflow.mode is MlflowMode.disabled: + CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]") + raise typer.Exit(1) + + st = state_ops.store(config) + if not job_name: + job_name = st.get_last_training_job() + if not job_name: + CONSOLE.print( + "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]" + ) + raise typer.Exit(1) + + if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force: + CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].") + return + + try: + result = upload_training_metrics( + job_name=job_name, + config_path=config, + cfg=cfg, + force=force, + ) + except Exception as e: + CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") + raise typer.Exit(1) + + if result.metrics_history_uploaded: + CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].") + else: + CONSOLE.print( + f"[yellow]No training_metrics.json was found in the SageMaker model artifact for " + f"[cyan]{job_name}[/cyan]. 
Uploaded SageMaker final metrics only.[/yellow]" + ) + CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]") + if result.registered_model_version: + CONSOLE.print( + f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " + "([cyan]experiment-latest[/cyan])" + ) diff --git a/src/commands/train.py b/src/commands/train.py index eb96c92..31356d3 100644 --- a/src/commands/train.py +++ b/src/commands/train.py @@ -1,3 +1,4 @@ +import time from datetime import datetime from pathlib import Path @@ -11,6 +12,7 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg from src.config import Config, MlflowMode from src.infra.state import read_infra_state from src.tracking.mlflow import MlflowTracker +from src.tracking.upload import upload_training_metrics app = typer.Typer(help="Manage SageMaker training jobs") @@ -21,6 +23,8 @@ _STATUS_COLOR = { "Stopping": "yellow", "Stopped": "dim", } +_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"} +DEFAULT_POLL_INTERVAL_SECONDS = 30 def _tracker(cfg): @@ -48,11 +52,100 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str: raise RuntimeError("SageMaker role not found in infra state. 
Run 'qc-cli infra setup' first.") +def _print_training_status(status: sm_ops.TrainingJobStatus) -> None: + color = _STATUS_COLOR.get(status.status, "white") + CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]") + CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]") + if status.created: + CONSOLE.print(f"Created: {status.created}") + if status.model_artifacts: + CONSOLE.print(f"Artifacts: {status.model_artifacts}") + if status.failure_reason: + CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]") + + +def _wait_and_upload_metrics( + *, + job_name: str, + poll_interval: int, + config_path: str, + cfg: Config, +) -> None: + st = state_ops.store(config_path) + previous_status: str | None = None + try: + while True: + training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) + if training_status.status != previous_status: + color = _STATUS_COLOR.get(training_status.status, "white") + CONSOLE.print( + f"Job [cyan]{training_status.name}[/cyan]: " + f"[{color}]{training_status.status}[/{color}]" + ) + previous_status = training_status.status + if training_status.status in _TERMINAL_STATUSES: + _print_training_status(training_status) + if training_status.status != "Completed": + raise typer.Exit(1) + try: + result = upload_training_metrics( + job_name=job_name, + config_path=config_path, + cfg=cfg, + ) + if result.metrics_history_uploaded: + CONSOLE.print( + f"[green]✓[/green] Uploaded training metrics to MLflow run " + f"[cyan]{result.run_id}[/cyan]." + ) + else: + CONSOLE.print( + "[yellow]No training_metrics.json was found in the SageMaker model artifact. 
" + "Uploaded SageMaker final metrics only.[/yellow]" + ) + if result.registered_model_version: + CONSOLE.print( + f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " + "([cyan]experiment-latest[/cyan])" + ) + except Exception as e: + CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") + CONSOLE.print( + f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]." + ) + raise typer.Exit(1) + job_state = st.get_training_job(job_name) + if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: + CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") + return + time.sleep(poll_interval) + except KeyboardInterrupt: + CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]") + raise typer.Exit(130) + + @app.command() -def start(config: str = CONFIG_OPT) -> None: +def start( + upload_metrics: bool = typer.Option( + False, + "--upload-metrics", + help="Wait for completion, then upload training metrics to MLflow", + ), + poll_interval: int = typer.Option( + DEFAULT_POLL_INTERVAL_SECONDS, + "--poll-interval", + min=1, + help="Seconds between status checks when --upload-metrics is used", + ), + config: str = CONFIG_OPT, +) -> None: """Submit a SageMaker training job.""" cfg = load_cfg(config) + if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled: + CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]") + raise typer.Exit(1) + if not cfg.sagemaker.training.image_uri: CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]") CONSOLE.print( @@ -89,12 +182,20 @@ def start(config: str = CONFIG_OPT) -> None: st = state_ops.store(config) st.set_last_training_job(job_name) - run_id = tracker.start_training_run( - training_job, - region=cfg.aws.region, - profile=cfg.aws.profile, - role_arn=role_arn, - ) + try: + run_id = tracker.start_training_run( + training_job, + region=cfg.aws.region, + 
profile=cfg.aws.profile, + role_arn=role_arn, + ) + except Exception as e: + run_id = None + CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]") + CONSOLE.print( + "The SageMaker job is still running. Upload metrics after completion with " + f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]." + ) if run_id: st.update_training_job(job_name, mlflow_run_id=run_id) @@ -102,7 +203,15 @@ def start(config: str = CONFIG_OPT) -> None: if run_id: CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") - CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") + if upload_metrics: + _wait_and_upload_metrics( + job_name=job_name, + poll_interval=poll_interval, + config_path=config, + cfg=cfg, + ) + else: + CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") @app.command() @@ -123,35 +232,7 @@ def status( raise typer.Exit(1) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) - color = _STATUS_COLOR.get(status.status, "white") - - CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]") - CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]") - if status.created: - CONSOLE.print(f"Created: {status.created}") - if status.model_artifacts: - CONSOLE.print(f"Artifacts: {status.model_artifacts}") - if status.failure_reason: - CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]") - - job_state = st.get_training_job(job_name) - run_id = job_state.get("mlflow_run_id") - already_registered = job_state.get("registered_model_version") - if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}: - tracker = _tracker(cfg) - version = tracker.finalize_training_run( - run_id=str(run_id), - training_job_status=status, - ) - updates = {"mlflow_finalized_status": status.status} - if version: - updates["registered_model_version"] = version - st.update_training_job(job_name, **updates) - if version: - 
st.set_latest_experiment_model_version(version) - CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])") - if run_id and cfg.mlflow.mode is not MlflowMode.disabled: - CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") + _print_training_status(status) @app.command(name="list") diff --git a/src/tracking/__init__.py b/src/tracking/__init__.py index 931f89f..6e63597 100644 --- a/src/tracking/__init__.py +++ b/src/tracking/__init__.py @@ -1,3 +1,3 @@ -from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker +from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker -__all__ = ["MlflowTracker", "NoopTracker", "Tracker"] +__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"] diff --git a/src/tracking/metrics.py b/src/tracking/metrics.py new file mode 100644 index 0000000..bb73a80 --- /dev/null +++ b/src/tracking/metrics.py @@ -0,0 +1,93 @@ +import json +import math +import tarfile +from dataclasses import dataclass +from pathlib import PurePosixPath +from typing import Any + +METRICS_ARTIFACT_NAME = "training_metrics.json" +METRICS_SCHEMA_VERSION = 1 +MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024 + + +@dataclass(frozen=True) +class MetricStep: + step: int + metrics: dict[str, float] + + +@dataclass(frozen=True) +class TrainingMetrics: + steps: list[MetricStep] + summary: dict[str, float] + raw: dict[str, Any] + + +def parse_training_metrics(data: bytes) -> TrainingMetrics: + try: + value = json.loads(data) + except (UnicodeDecodeError, json.JSONDecodeError) as exc: + raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc + if not isinstance(value, dict): + raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object") + if value.get("schema_version") != METRICS_SCHEMA_VERSION: + raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}") + + raw_steps = value.get("steps") + if not isinstance(raw_steps, 
list): + raise ValueError("training metrics 'steps' must be a list") + + steps: list[MetricStep] = [] + previous_step: int | None = None + for index, raw_step in enumerate(raw_steps): + if not isinstance(raw_step, dict): + raise ValueError(f"training metrics step {index} must be an object") + step = raw_step.get("step") + if isinstance(step, bool) or not isinstance(step, int) or step < 0: + raise ValueError(f"training metrics step {index} has an invalid 'step'") + if previous_step is not None and step <= previous_step: + raise ValueError("training metrics steps must be unique and strictly increasing") + metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}") + steps.append(MetricStep(step=step, metrics=metrics)) + previous_step = step + + summary = _numeric_metrics(value.get("summary", {}), "training metrics summary") + return TrainingMetrics(steps=steps, summary=summary, raw=value) + + +def read_training_metrics_from_tar(archive_path: str) -> bytes | None: + with tarfile.open(archive_path, mode="r:*") as archive: + matches = [ + member + for member in archive.getmembers() + if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME + ] + if not matches: + return None + if len(matches) > 1: + raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files") + if matches[0].size > MAX_METRICS_ARTIFACT_BYTES: + raise ValueError( + f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit" + ) + extracted = archive.extractfile(matches[0]) + if extracted is None: + raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive") + return extracted.read() + + +def _numeric_metrics(value: Any, context: str) -> dict[str, float]: + if not isinstance(value, dict): + raise ValueError(f"{context} 'metrics' must be an object") + + metrics: dict[str, float] = {} + for raw_name, raw_value in value.items(): + if not isinstance(raw_name, str) or not raw_name: + raise 
ValueError(f"{context} contains an invalid metric name") + if isinstance(raw_value, bool) or not isinstance(raw_value, int | float): + raise ValueError(f"{context} metric '{raw_name}' must be numeric") + metric_value = float(raw_value) + if not math.isfinite(metric_value): + raise ValueError(f"{context} metric '{raw_name}' must be finite") + metrics[raw_name] = metric_value + return metrics diff --git a/src/tracking/mlflow.py b/src/tracking/mlflow.py index 0e8f5d0..287e259 100644 --- a/src/tracking/mlflow.py +++ b/src/tracking/mlflow.py @@ -1,18 +1,46 @@ import os +import tempfile from dataclasses import dataclass from typing import Any, Protocol import mlflow from mlflow.tracking import MlflowClient -from src.aws import mlflow as aws_mlflow +from src.aws import s3 +from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config from src.config import Config, MlflowMode +from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar + + +@dataclass(frozen=True) +class FinalizeResult: + registered_model_version: str | None = None + warnings: tuple[str, ...] = () class Tracker(Protocol): def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ... - def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: ... + def finalize_training_run( + self, + *, + run_id: str | None, + training_job_status: Any, + region: str, + profile: str, + command: str, + ) -> FinalizeResult: ... + + def ensure_training_run(self, job_name: str) -> str: ... + + def upload_training_metrics( + self, + *, + run_id: str, + training_job_status: Any, + region: str, + profile: str, + ) -> bool: ... 
@dataclass(frozen=True) @@ -20,8 +48,29 @@ class NoopTracker: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: return None - def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: - return None + def finalize_training_run( + self, + *, + run_id: str | None, + training_job_status: Any, + region: str, + profile: str, + command: str, + ) -> FinalizeResult: + return FinalizeResult() + + def ensure_training_run(self, job_name: str) -> str: + raise RuntimeError("MLflow is disabled.") + + def upload_training_metrics( + self, + *, + run_id: str, + training_job_status: Any, + region: str, + profile: str, + ) -> bool: + raise RuntimeError("MLflow is disabled.") @dataclass(frozen=True) @@ -30,6 +79,7 @@ class MlflowTracker: experiment_name: str registered_model_name: str register_trained_models: bool + tracking_backend: MlflowTrackingBackend @classmethod def from_config(cls, cfg: Config) -> Tracker: @@ -42,94 +92,138 @@ class MlflowTracker: if not tracking_server_name: raise RuntimeError("MLflow tracking server name could not be resolved.") - tracking_uri = aws_mlflow.get_tracking_server_arn( - cfg.aws.region, - cfg.aws.profile, - tracking_server_name, - ) - mlflow.set_tracking_uri(tracking_uri) - mlflow.set_experiment(cfg.mlflow.experiment_name) + tracking_backend = mlflow_tracking_backend_from_config(cfg) + + tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name) + with tracking_backend.auth_env(): + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment(cfg.mlflow.experiment_name) return cls( tracking_uri=tracking_uri, experiment_name=cfg.mlflow.experiment_name, registered_model_name=cfg.mlflow.registered_model_name, register_trained_models=cfg.mlflow.register_trained_models, + tracking_backend=tracking_backend, ) def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: - run = 
mlflow.start_run(run_name=training_job.job_name) - run_id = str(run.info.run_id) + with self.tracking_backend.auth_env(): + with mlflow.start_run(run_name=training_job.job_name) as run: + run_id = str(run.info.run_id) + self._log_params( + self.tracking_backend.training_run_params( + training_job, + region=region, + profile=profile, + role_arn=role_arn, + ) + ) + self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) + mlflow.set_tags( + { + "qc_cli.stage": "experiment", + "qc_cli.artifact_kind": "trained_source", + "qc_cli.source": self.tracking_backend.provider_name, + "qc_cli.command": "train start", + **self.tracking_backend.training_run_tags(training_job), + } + ) + return run_id - params = { - "aws.region": region, - "aws.profile": profile, - "sagemaker.role_arn": role_arn, - "sagemaker.job_name": training_job.job_name, - "sagemaker.training_image": training_job.image_uri, - "sagemaker.instance_type": training_job.instance_type, - "sagemaker.instance_count": training_job.instance_count, - "sagemaker.s3_train_uri": training_job.s3_train_uri, - "sagemaker.s3_output_path": training_job.s3_output_path, - "sagemaker.entry_point": training_job.entry_point, - "sagemaker.source_dir": training_job.source_dir, - } - self._log_params(params) - self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) - mlflow.set_tags( - { - "qc_cli.stage": "experiment", - "qc_cli.artifact_kind": "trained_source", - "qc_cli.source": "sagemaker", - "qc_cli.command": "train start", - "sagemaker.job_name": training_job.job_name, - } - ) - mlflow.end_run() - return run_id - - def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: + def finalize_training_run( + self, + *, + run_id: str | None, + training_job_status: Any, + region: str, + profile: str, + command: str, + ) -> FinalizeResult: if not run_id: - return None + return FinalizeResult() - with 
mlflow.start_run(run_id=run_id): - self._log_params( - { - "sagemaker.training_status": training_job_status.status, - "sagemaker.created_at": training_job_status.created, - "sagemaker.modified_at": training_job_status.modified, - "sagemaker.model_artifacts": training_job_status.model_artifacts, - "sagemaker.failure_reason": training_job_status.failure_reason, - } - ) - self._log_final_metrics(training_job_status.raw) - mlflow.set_tag("qc_cli.command", "train status") + with self.tracking_backend.auth_env(): + with mlflow.start_run(run_id=run_id): + self._log_params(self.tracking_backend.training_status_params(training_job_status)) + self._log_final_metrics(training_job_status.raw) + mlflow.set_tag("qc_cli.command", command) - if training_job_status.status != "Completed" or not training_job_status.model_artifacts: - mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) - return None + if training_job_status.status != "Completed" or not training_job_status.model_artifacts: + mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) + return FinalizeResult() - if not self.register_trained_models: - return None + if not self.register_trained_models: + return FinalizeResult() + client = MlflowClient() + self._ensure_registered_model(client, self.registered_model_name) + version = client.create_model_version( + name=self.registered_model_name, + source=training_job_status.model_artifacts, + run_id=run_id, + tags={ + "qc_cli.stage": "experiment", + "qc_cli.artifact_kind": "trained_source", + "qc_cli.source": self.tracking_backend.provider_name, + **self.tracking_backend.model_version_tags(training_job_status), + }, + ) + version_number = str(version.version) + client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) + mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) + mlflow.set_tag("qc_cli.registered_model_version", version_number) + return 
FinalizeResult(registered_model_version=version_number) + + def ensure_training_run(self, job_name: str) -> str: + with self.tracking_backend.auth_env(): client = MlflowClient() - self._ensure_registered_model(client, self.registered_model_name) - version = client.create_model_version( - name=self.registered_model_name, - source=training_job_status.model_artifacts, - run_id=run_id, + experiment = client.get_experiment_by_name(self.experiment_name) + if experiment is None: + experiment_id = mlflow.create_experiment(self.experiment_name) + else: + experiment_id = experiment.experiment_id + + for run in client.search_runs([experiment_id], max_results=1000): + if run.data.tags.get("sagemaker.job_name") == job_name: + return str(run.info.run_id) + + run = client.create_run( + experiment_id, + run_name=job_name, tags={ "qc_cli.stage": "experiment", "qc_cli.artifact_kind": "trained_source", - "qc_cli.source": "sagemaker", - "sagemaker.job_name": training_job_status.name, + "qc_cli.source": self.tracking_backend.provider_name, + "qc_cli.command": "mlflow upload-metrics", + "sagemaker.job_name": job_name, }, ) - version_number = str(version.version) - client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) - mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) - mlflow.set_tag("qc_cli.registered_model_version", version_number) - return version_number + return str(run.info.run_id) + + def upload_training_metrics( + self, + *, + run_id: str, + training_job_status: Any, + region: str, + profile: str, + ) -> bool: + if not training_job_status.model_artifacts: + raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.") + + with self.tracking_backend.auth_env(): + with mlflow.start_run(run_id=run_id): + self._log_params(self.tracking_backend.training_status_params(training_job_status)) + self._log_final_metrics(training_job_status.raw) + history_uploaded = self._log_training_metrics( + 
training_job_status.model_artifacts, + region=region, + profile=profile, + ) + mlflow.set_tag("qc_cli.command", "mlflow upload-metrics") + mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower()) + return history_uploaded def _log_params(self, params: dict[str, Any]) -> None: cleaned = {key: str(value) for key, value in params.items() if value is not None} @@ -146,6 +240,26 @@ class MlflowTracker: if metrics: mlflow.log_metrics(metrics) + def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool: + with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir: + archive_path = s3.download_file( + region, + profile, + model_artifacts, + os.path.join(temp_dir, "model.tar.gz"), + ) + metrics_data = read_training_metrics_from_tar(archive_path) + if metrics_data is None: + return False + metrics = parse_training_metrics(metrics_data) + for metric_step in metrics.steps: + if metric_step.metrics: + mlflow.log_metrics(metric_step.metrics, step=metric_step.step) + if metrics.summary: + mlflow.log_metrics(metrics.summary) + mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME) + return True + def _ensure_registered_model(self, client: MlflowClient, name: str) -> None: try: client.get_registered_model(name) diff --git a/src/tracking/upload.py b/src/tracking/upload.py new file mode 100644 index 0000000..5cf77a3 --- /dev/null +++ b/src/tracking/upload.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass + +from src import state as state_ops +from src.aws import sagemaker as sm_ops +from src.config import Config, MlflowMode +from src.tracking.mlflow import MlflowTracker + + +@dataclass(frozen=True) +class MetricsUploadResult: + run_id: str + registered_model_version: str | None = None + metrics_history_uploaded: bool = True + + +def upload_training_metrics( + *, + job_name: str, + config_path: str, + cfg: Config, + force: bool = False, +) -> MetricsUploadResult: + if cfg.mlflow.mode is MlflowMode.disabled: + 
raise RuntimeError("MLflow is disabled in config.yaml.") + + st = state_ops.store(config_path) + job_state = st.get_training_job(job_name) + if job_state.get("mlflow_metrics_uploaded") and not force: + return MetricsUploadResult( + run_id=str(job_state.get("mlflow_run_id") or ""), + registered_model_version=( + str(job_state["registered_model_version"]) + if job_state.get("registered_model_version") + else None + ), + metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)), + ) + + status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) + if status.status != "Completed": + raise RuntimeError( + f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion." + ) + + tracker = MlflowTracker.from_config(cfg) + run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name)) + st.update_training_job(job_name, mlflow_run_id=run_id) + metrics_history_uploaded = tracker.upload_training_metrics( + run_id=run_id, + training_job_status=status, + region=cfg.aws.region, + profile=cfg.aws.profile, + ) + finalized = tracker.finalize_training_run( + run_id=run_id, + training_job_status=status, + region=cfg.aws.region, + profile=cfg.aws.profile, + command="mlflow upload-metrics", + ) + updates = { + "mlflow_metrics_uploaded": True, + "mlflow_metrics_history_uploaded": metrics_history_uploaded, + "mlflow_finalized_status": status.status, + } + if finalized.registered_model_version: + updates["registered_model_version"] = finalized.registered_model_version + st.update_training_job(job_name, **updates) + if finalized.registered_model_version: + st.set_latest_experiment_model_version(finalized.registered_model_version) + return MetricsUploadResult( + run_id=run_id, + registered_model_version=finalized.registered_model_version, + metrics_history_uploaded=metrics_history_uploaded, + )
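
Note on the skip-unless-forced guard in `src/tracking/upload.py`: `upload_training_metrics` returns early when the job state already has `mlflow_metrics_uploaded` set and `force` is false, and only records the flag after the upload succeeds. The behavior can be sketched in isolation as below. This is a minimal illustration, not the repository's code: `InMemoryJobStore` and `upload_once` are hypothetical names, and the in-memory dict stands in for the `.qc-cli.json` state store.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class InMemoryJobStore:
    """Hypothetical stand-in for the per-job state kept in `.qc-cli.json`."""

    jobs: dict[str, dict] = field(default_factory=dict)

    def get_training_job(self, job_name: str) -> dict:
        # Return a copy so callers cannot mutate stored state by accident.
        return dict(self.jobs.get(job_name, {}))

    def update_training_job(self, job_name: str, **updates) -> None:
        self.jobs.setdefault(job_name, {}).update(updates)


def upload_once(
    store: InMemoryJobStore,
    job_name: str,
    upload: Callable[[str], bool],
    *,
    force: bool = False,
) -> bool:
    """Run `upload` unless it already succeeded for this job.

    Returns True when the upload ran and False when it was skipped.
    """
    state = store.get_training_job(job_name)
    if state.get("mlflow_metrics_uploaded") and not force:
        return False

    # The flag is written only after `upload` returns, so a failed
    # upload (an exception here) leaves the job eligible for a retry.
    history_uploaded = upload(job_name)
    store.update_training_job(
        job_name,
        mlflow_metrics_uploaded=True,
        mlflow_metrics_history_uploaded=history_uploaded,
    )
    return True
```

Writing the flag after the upload call, rather than before it, is what keeps a failed upload retryable without `--force`.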