ai-hub #3

Merged
slalom merged 17 commits from ai-hub into main 2026-06-03 21:06:06 +00:00
4 changed files with 52 additions and 8 deletions
Showing only changes of commit 57a8a0a9c4 - Show all commits

View File

@@ -93,7 +93,7 @@ mlflow:
tracking_server_name: your-tracking-server-name tracking_server_name: your-tracking-server-name
``` ```
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as pre-release model versions using the `prerelease-latest` MLflow alias. When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release.
To open the managed SageMaker MLflow UI, request a fresh presigned URL: To open the managed SageMaker MLflow UI, request a fresh presigned URL:
@@ -155,6 +155,46 @@ qc-cli train list --limit 3 Show a custom number of recent jobs
The expected output artifact is SageMakers `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. The expected output artifact is SageMakers `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
## Model lifecycle
The CLI uses neutral experiment naming for trained artifacts and reserves release terminology for an explicit promotion step.
Current behavior:
1. `qc-cli train start` submits a SageMaker training job.
2. `qc-cli train status` finalizes the MLflow run after the job reaches a terminal state.
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
- `qc_cli.stage=experiment`
- `qc_cli.artifact_kind=trained_source`
- `qc_cli.source=sagemaker`
4. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
Planned AI Hub extension:
1. AI Hub compile or quantize will create deployable derived artifacts from a trained-source experiment.
2. Derived artifacts will keep lineage back to the source experiment version instead of replacing it.
3. Release aliases such as `v1` or `production` will point at the selected deployable artifact.
Example future metadata:
```text
qc-cli-model version 12
qc_cli.stage=experiment
qc_cli.artifact_kind=trained_source
qc_cli.source=sagemaker
qc-cli-model-aihub version 3
qc_cli.stage=ai_hub_compiled
qc_cli.artifact_kind=deployable
qc_cli.parent_registered_model_name=qc-cli-model
qc_cli.parent_model_version=12
qc_cli.runtime=tflite
qc_cli.quantization=int8
qc_cli.target_device=Samsung Galaxy S25
```
In that flow, `experiment-latest` remains a training convenience alias. Release selection is a separate promotion decision based on the derived artifact, not on the experiment name.
## AWS permissions required ## AWS permissions required
The IAM user or role running the CLI needs: The IAM user or role running the CLI needs:

View File

@@ -148,8 +148,8 @@ def status(
updates["registered_model_version"] = version updates["registered_model_version"] = version
st.update_training_job(job_name, **updates) st.update_training_job(job_name, **updates)
if version: if version:
st.set_latest_prerelease_model_version(version) st.set_latest_experiment_model_version(version)
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]prerelease-latest[/cyan])") CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
if run_id and cfg.mlflow.mode is not MlflowMode.disabled: if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")

View File

@@ -48,8 +48,8 @@ class CliStateStore:
state["training_jobs"] = jobs state["training_jobs"] = jobs
self._write(state) self._write(state)
def set_latest_prerelease_model_version(self, version: str) -> None: def set_latest_experiment_model_version(self, version: str) -> None:
self.update(latest_prerelease_model_version=version) self.update(latest_experiment_model_version=version)
def _write(self, state: dict[str, Any]) -> None: def _write(self, state: dict[str, Any]) -> None:
with open(self.path, "w") as f: with open(self.path, "w") as f:

View File

@@ -78,7 +78,9 @@ class MlflowTracker:
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags( mlflow.set_tags(
{ {
"qc_cli.stage": "prerelease", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"qc_cli.command": "train start", "qc_cli.command": "train start",
"sagemaker.job_name": training_job.job_name, "sagemaker.job_name": training_job.job_name,
} }
@@ -117,12 +119,14 @@ class MlflowTracker:
source=training_job_status.model_artifacts, source=training_job_status.model_artifacts,
run_id=run_id, run_id=run_id,
tags={ tags={
"qc_cli.stage": "prerelease", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"sagemaker.job_name": training_job_status.name, "sagemaker.job_name": training_job_status.name,
}, },
) )
version_number = str(version.version) version_number = str(version.version)
client.set_registered_model_alias(self.registered_model_name, "prerelease-latest", version_number) client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
mlflow.set_tag("qc_cli.registered_model_version", version_number) mlflow.set_tag("qc_cli.registered_model_version", version_number)
return version_number return version_number