MLflow implementation #2
@@ -93,12 +93,6 @@ mlflow:
|
|||||||
tracking_server_name: your-tracking-server-name
|
tracking_server_name: your-tracking-server-name
|
||||||
```
|
```
|
||||||
|
|
||||||
Install the optional MLflow dependencies before enabling MLflow:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv sync --extra mlflow
|
|
||||||
```
|
|
||||||
|
|
||||||
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as pre-release model versions using the `prerelease-latest` MLflow alias.
|
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as pre-release model versions using the `prerelease-latest` MLflow alias.
|
||||||
|
|
||||||
To open the managed SageMaker MLflow UI, request a fresh presigned URL:
|
To open the managed SageMaker MLflow UI, request a fresh presigned URL:
|
||||||
|
|||||||
@@ -12,13 +12,9 @@ dependencies = [
|
|||||||
"typer==0.25.0",
|
"typer==0.25.0",
|
||||||
"boto3>=1.34,<1.42",
|
"boto3>=1.34,<1.42",
|
||||||
"constructs>=10.0.0",
|
"constructs>=10.0.0",
|
||||||
|
"mlflow>=3.0",
|
||||||
"pydantic>=2.13.3",
|
"pydantic>=2.13.3",
|
||||||
"pyyaml>=6.0.3",
|
"pyyaml>=6.0.3",
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
mlflow = [
|
|
||||||
"mlflow>=3.0",
|
|
||||||
"sagemaker-mlflow>=0.4.0",
|
"sagemaker-mlflow>=0.4.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ import os
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
from src.aws import mlflow as aws_mlflow
|
from src.aws import mlflow as aws_mlflow
|
||||||
from src.config import Config, MlflowMode
|
from src.config import Config, MlflowMode
|
||||||
|
|
||||||
@@ -30,7 +33,6 @@ class NoopTracker:
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MlflowTracker:
|
class MlflowTracker:
|
||||||
mlflow: Any
|
|
||||||
tracking_uri: str
|
tracking_uri: str
|
||||||
experiment_name: str
|
experiment_name: str
|
||||||
registered_model_name: str
|
registered_model_name: str
|
||||||
@@ -43,14 +45,6 @@ class MlflowTracker:
|
|||||||
|
|
||||||
os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")
|
os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")
|
||||||
|
|
||||||
try:
|
|
||||||
import mlflow
|
|
||||||
except ImportError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
"MLflow is enabled in config but optional dependencies are not installed. "
|
|
||||||
"Install with: qc-cli[mlflow]"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
||||||
if not tracking_server_name:
|
if not tracking_server_name:
|
||||||
raise RuntimeError("MLflow tracking server name could not be resolved.")
|
raise RuntimeError("MLflow tracking server name could not be resolved.")
|
||||||
@@ -64,7 +58,6 @@ class MlflowTracker:
|
|||||||
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
mlflow=mlflow,
|
|
||||||
tracking_uri=tracking_uri,
|
tracking_uri=tracking_uri,
|
||||||
experiment_name=cfg.mlflow.experiment_name,
|
experiment_name=cfg.mlflow.experiment_name,
|
||||||
registered_model_name=cfg.mlflow.registered_model_name,
|
registered_model_name=cfg.mlflow.registered_model_name,
|
||||||
@@ -72,7 +65,7 @@ class MlflowTracker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||||
run = self.mlflow.start_run(run_name=training_job.job_name)
|
run = mlflow.start_run(run_name=training_job.job_name)
|
||||||
run_id = str(run.info.run_id)
|
run_id = str(run.info.run_id)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
@@ -90,21 +83,21 @@ class MlflowTracker:
|
|||||||
}
|
}
|
||||||
self._log_params(params)
|
self._log_params(params)
|
||||||
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
||||||
self.mlflow.set_tags(
|
mlflow.set_tags(
|
||||||
{
|
{
|
||||||
"qc_cli.stage": "prerelease",
|
"qc_cli.stage": "prerelease",
|
||||||
"qc_cli.command": "train start",
|
"qc_cli.command": "train start",
|
||||||
"sagemaker.job_name": training_job.job_name,
|
"sagemaker.job_name": training_job.job_name,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.mlflow.end_run()
|
mlflow.end_run()
|
||||||
return run_id
|
return run_id
|
||||||
|
|
||||||
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
|
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
|
||||||
if not run_id:
|
if not run_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with self.mlflow.start_run(run_id=run_id):
|
with mlflow.start_run(run_id=run_id):
|
||||||
self._log_params(
|
self._log_params(
|
||||||
{
|
{
|
||||||
"sagemaker.training_status": training_job_status.status,
|
"sagemaker.training_status": training_job_status.status,
|
||||||
@@ -115,16 +108,16 @@ class MlflowTracker:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
self._log_final_metrics(training_job_status.raw)
|
self._log_final_metrics(training_job_status.raw)
|
||||||
self.mlflow.set_tag("qc_cli.command", "train status")
|
mlflow.set_tag("qc_cli.command", "train status")
|
||||||
|
|
||||||
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
||||||
self.mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not self.register_trained_models:
|
if not self.register_trained_models:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
client = self.mlflow.tracking.MlflowClient()
|
client = MlflowClient()
|
||||||
self._ensure_registered_model(client, self.registered_model_name)
|
self._ensure_registered_model(client, self.registered_model_name)
|
||||||
version = client.create_model_version(
|
version = client.create_model_version(
|
||||||
name=self.registered_model_name,
|
name=self.registered_model_name,
|
||||||
@@ -137,14 +130,14 @@ class MlflowTracker:
|
|||||||
)
|
)
|
||||||
version_number = str(version.version)
|
version_number = str(version.version)
|
||||||
self._set_alias(client, self.registered_model_name, "prerelease-latest", version_number)
|
self._set_alias(client, self.registered_model_name, "prerelease-latest", version_number)
|
||||||
self.mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
||||||
self.mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
||||||
return version_number
|
return version_number
|
||||||
|
|
||||||
def _log_params(self, params: dict[str, Any]) -> None:
|
def _log_params(self, params: dict[str, Any]) -> None:
|
||||||
cleaned = {key: str(value) for key, value in params.items() if value is not None}
|
cleaned = {key: str(value) for key, value in params.items() if value is not None}
|
||||||
if cleaned:
|
if cleaned:
|
||||||
self.mlflow.log_params(cleaned)
|
mlflow.log_params(cleaned)
|
||||||
|
|
||||||
def _log_final_metrics(self, training_job: dict[str, Any]) -> None:
|
def _log_final_metrics(self, training_job: dict[str, Any]) -> None:
|
||||||
metrics = {}
|
metrics = {}
|
||||||
@@ -154,14 +147,13 @@ class MlflowTracker:
|
|||||||
if name and value is not None:
|
if name and value is not None:
|
||||||
metrics[str(name)] = float(value)
|
metrics[str(name)] = float(value)
|
||||||
if metrics:
|
if metrics:
|
||||||
self.mlflow.log_metrics(metrics)
|
mlflow.log_metrics(metrics)
|
||||||
|
|
||||||
def _ensure_registered_model(self, client: Any, name: str) -> None:
|
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
|
||||||
try:
|
try:
|
||||||
client.get_registered_model(name)
|
client.get_registered_model(name)
|
||||||
except Exception:
|
except Exception:
|
||||||
client.create_registered_model(name)
|
client.create_registered_model(name)
|
||||||
|
|
||||||
def _set_alias(self, client: Any, name: str, alias: str, version: str) -> None:
|
def _set_alias(self, client: MlflowClient, name: str, alias: str, version: str) -> None:
|
||||||
if hasattr(client, "set_registered_model_alias"):
|
|
||||||
client.set_registered_model_alias(name, alias, version)
|
client.set_registered_model_alias(name, alias, version)
|
||||||
|
|||||||
13
uv.lock
generated
13
uv.lock
generated
@@ -2011,15 +2011,11 @@ dependencies = [
|
|||||||
{ name = "aws-cdk-lib" },
|
{ name = "aws-cdk-lib" },
|
||||||
{ name = "boto3" },
|
{ name = "boto3" },
|
||||||
{ name = "constructs" },
|
{ name = "constructs" },
|
||||||
|
{ name = "mlflow" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
{ name = "pyyaml" },
|
{ name = "pyyaml" },
|
||||||
{ name = "typer" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.optional-dependencies]
|
|
||||||
mlflow = [
|
|
||||||
{ name = "mlflow" },
|
|
||||||
{ name = "sagemaker-mlflow" },
|
{ name = "sagemaker-mlflow" },
|
||||||
|
{ name = "typer" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dev-dependencies]
|
[package.dev-dependencies]
|
||||||
@@ -2036,13 +2032,12 @@ requires-dist = [
|
|||||||
{ name = "aws-cdk-lib", specifier = ">=2.180.0" },
|
{ name = "aws-cdk-lib", specifier = ">=2.180.0" },
|
||||||
{ name = "boto3", specifier = ">=1.34,<1.42" },
|
{ name = "boto3", specifier = ">=1.34,<1.42" },
|
||||||
{ name = "constructs", specifier = ">=10.0.0" },
|
{ name = "constructs", specifier = ">=10.0.0" },
|
||||||
{ name = "mlflow", marker = "extra == 'mlflow'", specifier = ">=3.0" },
|
{ name = "mlflow", specifier = ">=3.0" },
|
||||||
{ name = "pydantic", specifier = ">=2.13.3" },
|
{ name = "pydantic", specifier = ">=2.13.3" },
|
||||||
{ name = "pyyaml", specifier = ">=6.0.3" },
|
{ name = "pyyaml", specifier = ">=6.0.3" },
|
||||||
{ name = "sagemaker-mlflow", marker = "extra == 'mlflow'", specifier = ">=0.4.0" },
|
{ name = "sagemaker-mlflow", specifier = ">=0.4.0" },
|
||||||
{ name = "typer", specifier = "==0.25.0" },
|
{ name = "typer", specifier = "==0.25.0" },
|
||||||
]
|
]
|
||||||
provides-extras = ["mlflow"]
|
|
||||||
|
|
||||||
[package.metadata.requires-dev]
|
[package.metadata.requires-dev]
|
||||||
dev = [
|
dev = [
|
||||||
|
|||||||
Reference in New Issue
Block a user