This commit is contained in:
2026-06-12 12:21:41 -04:00
parent 5211d0af14
commit 4c33a016f0
4 changed files with 61 additions and 55 deletions

View File

@@ -105,7 +105,11 @@ mlflow:
tracking_server_name: your-tracking-server-name tracking_server_name: your-tracking-server-name
``` ```
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release. When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Metric upload through
`train start --upload-metrics` or `mlflow upload-metrics` finalizes that run and registers completed model artifacts
as experiment model versions using the `experiment-latest` MLflow alias. `train status` reads SageMaker status only.
An experiment version is an immutable trained-source artifact; it records that training produced a model, not that
the model is better than earlier versions or ready for release.
To open the managed SageMaker MLflow UI, request a fresh presigned URL: To open the managed SageMaker MLflow UI, request a fresh presigned URL:
@@ -224,10 +228,10 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
Current behavior: Current behavior:
1. `qc-cli train start` submits a SageMaker training job. 1. `qc-cli train start` submits a SageMaker training job.
2. `qc-cli train status` finalizes the MLflow run and registers completed model artifacts. 2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow.
3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion. 3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job. 4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
5. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: 5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with:
- `qc_cli.stage=experiment` - `qc_cli.stage=experiment`
- `qc_cli.artifact_kind=trained_source` - `qc_cli.artifact_kind=trained_source`
- `qc_cli.source=sagemaker` - `qc_cli.source=sagemaker`

View File

@@ -70,7 +70,7 @@ def upload_metrics(
return return
try: try:
run_id = upload_training_metrics( result = upload_training_metrics(
job_name=job_name, job_name=job_name,
config_path=config, config_path=config,
cfg=cfg, cfg=cfg,
@@ -81,4 +81,9 @@ def upload_metrics(
raise typer.Exit(1) raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].") CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
if result.registered_model_version:
CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
"([cyan]experiment-latest[/cyan])"
)

View File

@@ -64,45 +64,6 @@ def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]") CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
def _finalize_terminal_job(
*,
config_path: str,
cfg: Config,
status: sm_ops.TrainingJobStatus,
command: str,
) -> None:
if status.status not in _TERMINAL_STATUSES:
return
st = state_ops.store(config_path)
job_state = st.get_training_job(status.name)
run_id = job_state.get("mlflow_run_id")
if not run_id or job_state.get("mlflow_finalized_status"):
return
tracker = _tracker(cfg)
result = tracker.finalize_training_run(
run_id=str(run_id),
training_job_status=status,
region=cfg.aws.region,
profile=cfg.aws.profile,
command=command,
)
updates = {"mlflow_finalized_status": status.status}
if result.registered_model_version:
updates["registered_model_version"] = result.registered_model_version
st.update_training_job(status.name, **updates)
for warning in result.warnings:
CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]")
if result.registered_model_version:
st.set_latest_experiment_model_version(result.registered_model_version)
CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
"([cyan]experiment-latest[/cyan])"
)
def _wait_and_upload_metrics( def _wait_and_upload_metrics(
*, *,
job_name: str, job_name: str,
@@ -127,12 +88,20 @@ def _wait_and_upload_metrics(
if training_status.status != "Completed": if training_status.status != "Completed":
raise typer.Exit(1) raise typer.Exit(1)
try: try:
run_id = upload_training_metrics( result = upload_training_metrics(
job_name=job_name, job_name=job_name,
config_path=config_path, config_path=config_path,
cfg=cfg, cfg=cfg,
) )
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics to MLflow run [cyan]{run_id}[/cyan].") CONSOLE.print(
f"[green]✓[/green] Uploaded training metrics to MLflow run "
f"[cyan]{result.run_id}[/cyan]."
)
if result.registered_model_version:
CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
"([cyan]experiment-latest[/cyan])"
)
except Exception as e: except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
CONSOLE.print( CONSOLE.print(
@@ -258,11 +227,6 @@ def status(
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
_print_training_status(status) _print_training_status(status)
_finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status")
job_state = st.get_training_job(job_name)
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@app.command(name="list") @app.command(name="list")

View File

@@ -1,23 +1,38 @@
from dataclasses import dataclass
from src import state as state_ops from src import state as state_ops
from src.aws import sagemaker as sm_ops from src.aws import sagemaker as sm_ops
from src.config import Config, MlflowMode from src.config import Config, MlflowMode
from src.tracking.mlflow import MlflowTracker from src.tracking.mlflow import MlflowTracker
@dataclass(frozen=True)
class MetricsUploadResult:
run_id: str
registered_model_version: str | None = None
def upload_training_metrics( def upload_training_metrics(
*, *,
job_name: str, job_name: str,
config_path: str, config_path: str,
cfg: Config, cfg: Config,
force: bool = False, force: bool = False,
) -> str: ) -> MetricsUploadResult:
if cfg.mlflow.mode is MlflowMode.disabled: if cfg.mlflow.mode is MlflowMode.disabled:
raise RuntimeError("MLflow is disabled in config.yaml.") raise RuntimeError("MLflow is disabled in config.yaml.")
st = state_ops.store(config_path) st = state_ops.store(config_path)
job_state = st.get_training_job(job_name) job_state = st.get_training_job(job_name)
if job_state.get("mlflow_metrics_uploaded") and not force: if job_state.get("mlflow_metrics_uploaded") and not force:
return str(job_state.get("mlflow_run_id") or "") return MetricsUploadResult(
run_id=str(job_state.get("mlflow_run_id") or ""),
registered_model_version=(
str(job_state["registered_model_version"])
if job_state.get("registered_model_version")
else None
),
)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
if status.status != "Completed": if status.status != "Completed":
@@ -34,5 +49,23 @@ def upload_training_metrics(
region=cfg.aws.region, region=cfg.aws.region,
profile=cfg.aws.profile, profile=cfg.aws.profile,
) )
st.update_training_job(job_name, mlflow_metrics_uploaded=True) finalized = tracker.finalize_training_run(
return run_id run_id=run_id,
training_job_status=status,
region=cfg.aws.region,
profile=cfg.aws.profile,
command="mlflow upload-metrics",
)
updates = {
"mlflow_metrics_uploaded": True,
"mlflow_finalized_status": status.status,
}
if finalized.registered_model_version:
updates["registered_model_version"] = finalized.registered_model_version
st.update_training_job(job_name, **updates)
if finalized.registered_model_version:
st.set_latest_experiment_model_version(finalized.registered_model_version)
return MetricsUploadResult(
run_id=run_id,
registered_model_version=finalized.registered_model_version,
)