simplify
This commit is contained in:
10
README.md
10
README.md
@@ -105,7 +105,11 @@ mlflow:
|
||||
tracking_server_name: your-tracking-server-name
|
||||
```
|
||||
|
||||
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release.
|
||||
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Metric upload through
|
||||
`train start --upload-metrics` or `mlflow upload-metrics` finalizes that run and registers completed model artifacts
|
||||
as experiment model versions using the `experiment-latest` MLflow alias. `train status` reads SageMaker status only.
|
||||
An experiment version is an immutable trained-source artifact; it records that training produced a model, not that
|
||||
the model is better than earlier versions or ready for release.
|
||||
|
||||
To open the managed SageMaker MLflow UI, request a fresh presigned URL:
|
||||
|
||||
@@ -224,10 +228,10 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
|
||||
Current behavior:
|
||||
|
||||
1. `qc-cli train start` submits a SageMaker training job.
|
||||
2. `qc-cli train status` finalizes the MLflow run and registers completed model artifacts.
|
||||
2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow.
|
||||
3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
|
||||
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
|
||||
5. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
|
||||
5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with:
|
||||
- `qc_cli.stage=experiment`
|
||||
- `qc_cli.artifact_kind=trained_source`
|
||||
- `qc_cli.source=sagemaker`
|
||||
|
||||
@@ -70,7 +70,7 @@ def upload_metrics(
|
||||
return
|
||||
|
||||
try:
|
||||
run_id = upload_training_metrics(
|
||||
result = upload_training_metrics(
|
||||
job_name=job_name,
|
||||
config_path=config,
|
||||
cfg=cfg,
|
||||
@@ -81,4 +81,9 @@ def upload_metrics(
|
||||
raise typer.Exit(1)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
|
||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||
CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
|
||||
if result.registered_model_version:
|
||||
CONSOLE.print(
|
||||
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
||||
"([cyan]experiment-latest[/cyan])"
|
||||
)
|
||||
|
||||
@@ -64,45 +64,6 @@ def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
|
||||
|
||||
def _finalize_terminal_job(
|
||||
*,
|
||||
config_path: str,
|
||||
cfg: Config,
|
||||
status: sm_ops.TrainingJobStatus,
|
||||
command: str,
|
||||
) -> None:
|
||||
if status.status not in _TERMINAL_STATUSES:
|
||||
return
|
||||
|
||||
st = state_ops.store(config_path)
|
||||
job_state = st.get_training_job(status.name)
|
||||
run_id = job_state.get("mlflow_run_id")
|
||||
if not run_id or job_state.get("mlflow_finalized_status"):
|
||||
return
|
||||
|
||||
tracker = _tracker(cfg)
|
||||
result = tracker.finalize_training_run(
|
||||
run_id=str(run_id),
|
||||
training_job_status=status,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
command=command,
|
||||
)
|
||||
updates = {"mlflow_finalized_status": status.status}
|
||||
if result.registered_model_version:
|
||||
updates["registered_model_version"] = result.registered_model_version
|
||||
st.update_training_job(status.name, **updates)
|
||||
|
||||
for warning in result.warnings:
|
||||
CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]")
|
||||
if result.registered_model_version:
|
||||
st.set_latest_experiment_model_version(result.registered_model_version)
|
||||
CONSOLE.print(
|
||||
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
||||
"([cyan]experiment-latest[/cyan])"
|
||||
)
|
||||
|
||||
|
||||
def _wait_and_upload_metrics(
|
||||
*,
|
||||
job_name: str,
|
||||
@@ -127,12 +88,20 @@ def _wait_and_upload_metrics(
|
||||
if training_status.status != "Completed":
|
||||
raise typer.Exit(1)
|
||||
try:
|
||||
run_id = upload_training_metrics(
|
||||
result = upload_training_metrics(
|
||||
job_name=job_name,
|
||||
config_path=config_path,
|
||||
cfg=cfg,
|
||||
)
|
||||
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics to MLflow run [cyan]{run_id}[/cyan].")
|
||||
CONSOLE.print(
|
||||
f"[green]✓[/green] Uploaded training metrics to MLflow run "
|
||||
f"[cyan]{result.run_id}[/cyan]."
|
||||
)
|
||||
if result.registered_model_version:
|
||||
CONSOLE.print(
|
||||
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
||||
"([cyan]experiment-latest[/cyan])"
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
||||
CONSOLE.print(
|
||||
@@ -258,11 +227,6 @@ def status(
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
_print_training_status(status)
|
||||
_finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status")
|
||||
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
|
||||
@@ -1,23 +1,38 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src import state as state_ops
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.config import Config, MlflowMode
|
||||
from src.tracking.mlflow import MlflowTracker
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MetricsUploadResult:
|
||||
run_id: str
|
||||
registered_model_version: str | None = None
|
||||
|
||||
|
||||
def upload_training_metrics(
|
||||
*,
|
||||
job_name: str,
|
||||
config_path: str,
|
||||
cfg: Config,
|
||||
force: bool = False,
|
||||
) -> str:
|
||||
) -> MetricsUploadResult:
|
||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||
raise RuntimeError("MLflow is disabled in config.yaml.")
|
||||
|
||||
st = state_ops.store(config_path)
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_metrics_uploaded") and not force:
|
||||
return str(job_state.get("mlflow_run_id") or "")
|
||||
return MetricsUploadResult(
|
||||
run_id=str(job_state.get("mlflow_run_id") or ""),
|
||||
registered_model_version=(
|
||||
str(job_state["registered_model_version"])
|
||||
if job_state.get("registered_model_version")
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
if status.status != "Completed":
|
||||
@@ -34,5 +49,23 @@ def upload_training_metrics(
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
)
|
||||
st.update_training_job(job_name, mlflow_metrics_uploaded=True)
|
||||
return run_id
|
||||
finalized = tracker.finalize_training_run(
|
||||
run_id=run_id,
|
||||
training_job_status=status,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
command="mlflow upload-metrics",
|
||||
)
|
||||
updates = {
|
||||
"mlflow_metrics_uploaded": True,
|
||||
"mlflow_finalized_status": status.status,
|
||||
}
|
||||
if finalized.registered_model_version:
|
||||
updates["registered_model_version"] = finalized.registered_model_version
|
||||
st.update_training_job(job_name, **updates)
|
||||
if finalized.registered_model_version:
|
||||
st.set_latest_experiment_model_version(finalized.registered_model_version)
|
||||
return MetricsUploadResult(
|
||||
run_id=run_id,
|
||||
registered_model_version=finalized.registered_model_version,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user