This commit is contained in:
2026-06-12 14:10:52 -04:00
parent 3ec9c7b57a
commit 20cd3f9794
5 changed files with 37 additions and 17 deletions

View File

@@ -238,9 +238,9 @@ Current behavior:
6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. The explicit metrics
upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores
the JSON as a run artifact:
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When the file is
present, the explicit metrics upload command logs the file's ordered metrics to the associated MLflow run, using
each epoch as the MLflow step, and stores the JSON as a run artifact:
```json
{
@@ -253,8 +253,9 @@ the JSON as a run artifact:
```
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and
strictly increasing. A missing or malformed metrics artifact fails the upload command without affecting the trained
model or model registration.
strictly increasing. If the file is missing, the command uploads only the final metrics reported by SageMaker and
continues with model registration, without per-epoch history. A malformed metrics artifact still fails the upload
command without affecting the trained model or model registration.
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.

View File

@@ -80,7 +80,13 @@ def upload_metrics(
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
raise typer.Exit(1)
if result.metrics_history_uploaded:
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
else:
CONSOLE.print(
f"[yellow]No training_metrics.json was found in the SageMaker model artifact for "
f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]"
)
CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
if result.registered_model_version:
CONSOLE.print(

View File

@@ -93,10 +93,16 @@ def _wait_and_upload_metrics(
config_path=config_path,
cfg=cfg,
)
if result.metrics_history_uploaded:
CONSOLE.print(
f"[green]✓[/green] Uploaded training metrics to MLflow run "
f"[cyan]{result.run_id}[/cyan]."
)
else:
CONSOLE.print(
"[yellow]No training_metrics.json was found in the SageMaker model artifact. "
"Uploaded SageMaker final metrics only.[/yellow]"
)
if result.registered_model_version:
CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "

View File

@@ -40,7 +40,7 @@ class Tracker(Protocol):
training_job_status: Any,
region: str,
profile: str,
) -> None: ...
) -> bool: ...
@dataclass(frozen=True)
@@ -69,7 +69,7 @@ class NoopTracker:
training_job_status: Any,
region: str,
profile: str,
) -> None:
) -> bool:
raise RuntimeError("MLflow is disabled.")
@@ -208,7 +208,7 @@ class MlflowTracker:
training_job_status: Any,
region: str,
profile: str,
) -> None:
) -> bool:
if not training_job_status.model_artifacts:
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
@@ -216,12 +216,14 @@ class MlflowTracker:
with mlflow.start_run(run_id=run_id):
self._log_params(self.tracking_backend.training_status_params(training_job_status))
self._log_final_metrics(training_job_status.raw)
self._log_training_metrics(
history_uploaded = self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
return history_uploaded
def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None}
@@ -238,7 +240,7 @@ class MlflowTracker:
if metrics:
mlflow.log_metrics(metrics)
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> None:
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
archive_path = s3.download_file(
region,
@@ -248,7 +250,7 @@ class MlflowTracker:
)
metrics_data = read_training_metrics_from_tar(archive_path)
if metrics_data is None:
raise ValueError(f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact.")
return False
metrics = parse_training_metrics(metrics_data)
for metric_step in metrics.steps:
if metric_step.metrics:
@@ -256,6 +258,7 @@ class MlflowTracker:
if metrics.summary:
mlflow.log_metrics(metrics.summary)
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
return True
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try:

View File

@@ -10,6 +10,7 @@ from src.tracking.mlflow import MlflowTracker
class MetricsUploadResult:
run_id: str
registered_model_version: str | None = None
metrics_history_uploaded: bool = True
def upload_training_metrics(
@@ -32,6 +33,7 @@ def upload_training_metrics(
if job_state.get("registered_model_version")
else None
),
metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)),
)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
@@ -43,7 +45,7 @@ def upload_training_metrics(
tracker = MlflowTracker.from_config(cfg)
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
st.update_training_job(job_name, mlflow_run_id=run_id)
tracker.upload_training_metrics(
metrics_history_uploaded = tracker.upload_training_metrics(
run_id=run_id,
training_job_status=status,
region=cfg.aws.region,
@@ -58,6 +60,7 @@ def upload_training_metrics(
)
updates = {
"mlflow_metrics_uploaded": True,
"mlflow_metrics_history_uploaded": metrics_history_uploaded,
"mlflow_finalized_status": status.status,
}
if finalized.registered_model_version:
@@ -68,4 +71,5 @@ def upload_training_metrics(
return MetricsUploadResult(
run_id=run_id,
registered_model_version=finalized.registered_model_version,
metrics_history_uploaded=metrics_history_uploaded,
)