From 4c33a016f0b7243e85df5d7c426c9d5ac8823581 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 12 Jun 2026 12:21:41 -0400 Subject: [PATCH] simplify --- README.md | 10 +++++--- src/commands/mlflow.py | 9 +++++-- src/commands/train.py | 56 ++++++++---------------------------------- src/tracking/upload.py | 41 ++++++++++++++++++++++++++++--- 4 files changed, 61 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index afb8dd7..d750381 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,11 @@ mlflow: tracking_server_name: your-tracking-server-name ``` -When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release. +When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Metric upload through +`train start --upload-metrics` or `mlflow upload-metrics` finalizes that run and registers completed model artifacts +as experiment model versions using the `experiment-latest` MLflow alias. `train status` reads SageMaker status only. +An experiment version is an immutable trained-source artifact; it records that training produced a model, not that +the model is better than earlier versions or ready for release. To open the managed SageMaker MLflow UI, request a fresh presigned URL: @@ -224,10 +228,10 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas Current behavior: 1. `qc-cli train start` submits a SageMaker training job. -2. `qc-cli train status` finalizes the MLflow run and registers completed model artifacts. +2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow. 3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion. 4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job. -5. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: +5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with: - `qc_cli.stage=experiment` - `qc_cli.artifact_kind=trained_source` - `qc_cli.source=sagemaker` diff --git a/src/commands/mlflow.py b/src/commands/mlflow.py index 401a3d4..f282e2f 100644 --- a/src/commands/mlflow.py +++ b/src/commands/mlflow.py @@ -70,7 +70,7 @@ def upload_metrics( return try: - run_id = upload_training_metrics( + result = upload_training_metrics( job_name=job_name, config_path=config, cfg=cfg, @@ -81,4 +81,9 @@ def upload_metrics( raise typer.Exit(1) CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].") - CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") + CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]") + if result.registered_model_version: + CONSOLE.print( + f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " + "([cyan]experiment-latest[/cyan])" + ) diff --git a/src/commands/train.py b/src/commands/train.py index a6c3c1a..4fe0b1c 100644 --- a/src/commands/train.py +++ b/src/commands/train.py @@ -64,45 +64,6 @@ def _print_training_status(status: sm_ops.TrainingJobStatus) -> None: CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]") -def _finalize_terminal_job( - *, - config_path: str, - cfg: Config, - status: sm_ops.TrainingJobStatus, - command: str, -) -> None: - if status.status not in _TERMINAL_STATUSES: - return - - st = state_ops.store(config_path) - job_state = st.get_training_job(status.name) - run_id = job_state.get("mlflow_run_id") - if not run_id or job_state.get("mlflow_finalized_status"): - return - - tracker = _tracker(cfg) - result = tracker.finalize_training_run( - run_id=str(run_id), - training_job_status=status, - region=cfg.aws.region, - profile=cfg.aws.profile, - command=command, - ) - updates = {"mlflow_finalized_status": status.status} - if result.registered_model_version: - updates["registered_model_version"] = result.registered_model_version - st.update_training_job(status.name, **updates) - - for warning in result.warnings: - CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]") - if result.registered_model_version: - st.set_latest_experiment_model_version(result.registered_model_version) - CONSOLE.print( - f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " - "([cyan]experiment-latest[/cyan])" - ) - - def _wait_and_upload_metrics( *, job_name: str, @@ -127,12 +88,20 @@ def _wait_and_upload_metrics( if training_status.status != "Completed": raise typer.Exit(1) try: - run_id = upload_training_metrics( + result = upload_training_metrics( job_name=job_name, config_path=config_path, cfg=cfg, ) - CONSOLE.print(f"[green]✓[/green] Uploaded training metrics to MLflow run [cyan]{run_id}[/cyan].") + CONSOLE.print( + f"[green]✓[/green] Uploaded training metrics to MLflow run " + f"[cyan]{result.run_id}[/cyan]." + ) + if result.registered_model_version: + CONSOLE.print( + f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " + "([cyan]experiment-latest[/cyan])" + ) except Exception as e: CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") CONSOLE.print( @@ -258,11 +227,6 @@ def status( status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) _print_training_status(status) - _finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status") - - job_state = st.get_training_job(job_name) - if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: - CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") @app.command(name="list") diff --git a/src/tracking/upload.py b/src/tracking/upload.py index add9d98..78e6a41 100644 --- a/src/tracking/upload.py +++ b/src/tracking/upload.py @@ -1,23 +1,38 @@ +from dataclasses import dataclass + from src import state as state_ops from src.aws import sagemaker as sm_ops from src.config import Config, MlflowMode from src.tracking.mlflow import MlflowTracker +@dataclass(frozen=True) +class MetricsUploadResult: + run_id: str + registered_model_version: str | None = None + + def upload_training_metrics( *, job_name: str, config_path: str, cfg: Config, force: bool = False, -) -> str: +) -> MetricsUploadResult: if cfg.mlflow.mode is MlflowMode.disabled: raise RuntimeError("MLflow is disabled in config.yaml.") st = state_ops.store(config_path) job_state = st.get_training_job(job_name) if job_state.get("mlflow_metrics_uploaded") and not force: - return str(job_state.get("mlflow_run_id") or "") + return MetricsUploadResult( + run_id=str(job_state.get("mlflow_run_id") or ""), + registered_model_version=( + str(job_state["registered_model_version"]) + if job_state.get("registered_model_version") + else None + ), + ) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) if status.status != "Completed": @@ -34,5 +49,23 @@ def upload_training_metrics( region=cfg.aws.region, profile=cfg.aws.profile, ) - st.update_training_job(job_name, mlflow_metrics_uploaded=True) - return run_id + finalized = tracker.finalize_training_run( + run_id=run_id, + training_job_status=status, + region=cfg.aws.region, + profile=cfg.aws.profile, + command="mlflow upload-metrics", + ) + updates = { + "mlflow_metrics_uploaded": True, + "mlflow_finalized_status": status.status, + } + if finalized.registered_model_version: + updates["registered_model_version"] = finalized.registered_model_version + st.update_training_job(job_name, **updates) + if finalized.registered_model_version: + st.set_latest_experiment_model_version(finalized.registered_model_version) + return MetricsUploadResult( + run_id=run_id, + registered_model_version=finalized.registered_model_version, + )