This commit is contained in:
2026-06-12 14:10:52 -04:00
parent 3ec9c7b57a
commit 20cd3f9794
5 changed files with 37 additions and 17 deletions

View File

@@ -238,9 +238,9 @@ Current behavior:
6. The MLflow alias `experiment-latest` points at the most recently registered experiment version. 6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. 7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. The explicit metrics Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the
upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow
the JSON as a run artifact: step and stores the JSON as a run artifact:
```json ```json
{ {
@@ -253,8 +253,9 @@ the JSON as a run artifact:
``` ```
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and
strictly increasing. A missing or malformed metrics artifact fails the upload command without affecting the trained strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues
model or model registration. model registration without per-epoch history. A malformed metrics artifact still fails the upload command without
affecting the trained model or model registration.
Future release aliases such as `v1` or `production` can point at a selected deployable artifact. Future release aliases such as `v1` or `production` can point at a selected deployable artifact.

View File

@@ -80,7 +80,13 @@ def upload_metrics(
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]") CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
raise typer.Exit(1) raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].") if result.metrics_history_uploaded:
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
else:
CONSOLE.print(
f"[yellow]No training_metrics.json was found in the SageMaker model artifact for "
f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]"
)
CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]") CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
if result.registered_model_version: if result.registered_model_version:
CONSOLE.print( CONSOLE.print(

View File

@@ -93,10 +93,16 @@ def _wait_and_upload_metrics(
config_path=config_path, config_path=config_path,
cfg=cfg, cfg=cfg,
) )
CONSOLE.print( if result.metrics_history_uploaded:
f"[green]✓[/green] Uploaded training metrics to MLflow run " CONSOLE.print(
f"[cyan]{result.run_id}[/cyan]." f"[green]✓[/green] Uploaded training metrics to MLflow run "
) f"[cyan]{result.run_id}[/cyan]."
)
else:
CONSOLE.print(
"[yellow]No training_metrics.json was found in the SageMaker model artifact. "
"Uploaded SageMaker final metrics only.[/yellow]"
)
if result.registered_model_version: if result.registered_model_version:
CONSOLE.print( CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] " f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "

View File

@@ -40,7 +40,7 @@ class Tracker(Protocol):
training_job_status: Any, training_job_status: Any,
region: str, region: str,
profile: str, profile: str,
) -> None: ... ) -> bool: ...
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -69,7 +69,7 @@ class NoopTracker:
training_job_status: Any, training_job_status: Any,
region: str, region: str,
profile: str, profile: str,
) -> None: ) -> bool:
raise RuntimeError("MLflow is disabled.") raise RuntimeError("MLflow is disabled.")
@@ -208,7 +208,7 @@ class MlflowTracker:
training_job_status: Any, training_job_status: Any,
region: str, region: str,
profile: str, profile: str,
) -> None: ) -> bool:
if not training_job_status.model_artifacts: if not training_job_status.model_artifacts:
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.") raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
@@ -216,12 +216,14 @@ class MlflowTracker:
with mlflow.start_run(run_id=run_id): with mlflow.start_run(run_id=run_id):
self._log_params(self.tracking_backend.training_status_params(training_job_status)) self._log_params(self.tracking_backend.training_status_params(training_job_status))
self._log_final_metrics(training_job_status.raw) self._log_final_metrics(training_job_status.raw)
self._log_training_metrics( history_uploaded = self._log_training_metrics(
training_job_status.model_artifacts, training_job_status.model_artifacts,
region=region, region=region,
profile=profile, profile=profile,
) )
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics") mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
return history_uploaded
def _log_params(self, params: dict[str, Any]) -> None: def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None} cleaned = {key: str(value) for key, value in params.items() if value is not None}
@@ -238,7 +240,7 @@ class MlflowTracker:
if metrics: if metrics:
mlflow.log_metrics(metrics) mlflow.log_metrics(metrics)
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> None: def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir: with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
archive_path = s3.download_file( archive_path = s3.download_file(
region, region,
@@ -248,7 +250,7 @@ class MlflowTracker:
) )
metrics_data = read_training_metrics_from_tar(archive_path) metrics_data = read_training_metrics_from_tar(archive_path)
if metrics_data is None: if metrics_data is None:
raise ValueError(f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact.") return False
metrics = parse_training_metrics(metrics_data) metrics = parse_training_metrics(metrics_data)
for metric_step in metrics.steps: for metric_step in metrics.steps:
if metric_step.metrics: if metric_step.metrics:
@@ -256,6 +258,7 @@ class MlflowTracker:
if metrics.summary: if metrics.summary:
mlflow.log_metrics(metrics.summary) mlflow.log_metrics(metrics.summary)
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME) mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
return True
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None: def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try: try:

View File

@@ -10,6 +10,7 @@ from src.tracking.mlflow import MlflowTracker
class MetricsUploadResult: class MetricsUploadResult:
run_id: str run_id: str
registered_model_version: str | None = None registered_model_version: str | None = None
metrics_history_uploaded: bool = True
def upload_training_metrics( def upload_training_metrics(
@@ -32,6 +33,7 @@ def upload_training_metrics(
if job_state.get("registered_model_version") if job_state.get("registered_model_version")
else None else None
), ),
metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)),
) )
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
@@ -43,7 +45,7 @@ def upload_training_metrics(
tracker = MlflowTracker.from_config(cfg) tracker = MlflowTracker.from_config(cfg)
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name)) run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
st.update_training_job(job_name, mlflow_run_id=run_id) st.update_training_job(job_name, mlflow_run_id=run_id)
tracker.upload_training_metrics( metrics_history_uploaded = tracker.upload_training_metrics(
run_id=run_id, run_id=run_id,
training_job_status=status, training_job_status=status,
region=cfg.aws.region, region=cfg.aws.region,
@@ -58,6 +60,7 @@ def upload_training_metrics(
) )
updates = { updates = {
"mlflow_metrics_uploaded": True, "mlflow_metrics_uploaded": True,
"mlflow_metrics_history_uploaded": metrics_history_uploaded,
"mlflow_finalized_status": status.status, "mlflow_finalized_status": status.status,
} }
if finalized.registered_model_version: if finalized.registered_model_version:
@@ -68,4 +71,5 @@ def upload_training_metrics(
return MetricsUploadResult( return MetricsUploadResult(
run_id=run_id, run_id=run_id,
registered_model_version=finalized.registered_model_version, registered_model_version=finalized.registered_model_version,
metrics_history_uploaded=metrics_history_uploaded,
) )