mlflow is no longer an optional lib

This commit is contained in:
2026-05-29 14:29:05 -04:00
parent 58681cef82
commit 19fef8638b
4 changed files with 22 additions and 45 deletions

View File

@@ -93,12 +93,6 @@ mlflow:
tracking_server_name: your-tracking-server-name tracking_server_name: your-tracking-server-name
``` ```
Install the optional MLflow dependencies before enabling MLflow:
```bash
uv sync --extra mlflow
```
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as pre-release model versions using the `prerelease-latest` MLflow alias. When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as pre-release model versions using the `prerelease-latest` MLflow alias.
To open the managed SageMaker MLflow UI, request a fresh presigned URL: To open the managed SageMaker MLflow UI, request a fresh presigned URL:

View File

@@ -12,13 +12,9 @@ dependencies = [
"typer==0.25.0", "typer==0.25.0",
"boto3>=1.34,<1.42", "boto3>=1.34,<1.42",
"constructs>=10.0.0", "constructs>=10.0.0",
"mlflow>=3.0",
"pydantic>=2.13.3", "pydantic>=2.13.3",
"pyyaml>=6.0.3", "pyyaml>=6.0.3",
]
[project.optional-dependencies]
mlflow = [
"mlflow>=3.0",
"sagemaker-mlflow>=0.4.0", "sagemaker-mlflow>=0.4.0",
] ]

View File

@@ -4,6 +4,9 @@ import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Protocol from typing import Any, Protocol
import mlflow
from mlflow.tracking import MlflowClient
from src.aws import mlflow as aws_mlflow from src.aws import mlflow as aws_mlflow
from src.config import Config, MlflowMode from src.config import Config, MlflowMode
@@ -30,7 +33,6 @@ class NoopTracker:
@dataclass(frozen=True) @dataclass(frozen=True)
class MlflowTracker: class MlflowTracker:
mlflow: Any
tracking_uri: str tracking_uri: str
experiment_name: str experiment_name: str
registered_model_name: str registered_model_name: str
@@ -43,14 +45,6 @@ class MlflowTracker:
os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true") os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")
try:
import mlflow
except ImportError as e:
raise RuntimeError(
"MLflow is enabled in config but optional dependencies are not installed. "
"Install with: qc-cli[mlflow]"
) from e
tracking_server_name = cfg.effective_mlflow_tracking_server_name tracking_server_name = cfg.effective_mlflow_tracking_server_name
if not tracking_server_name: if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.") raise RuntimeError("MLflow tracking server name could not be resolved.")
@@ -64,7 +58,6 @@ class MlflowTracker:
mlflow.set_experiment(cfg.mlflow.experiment_name) mlflow.set_experiment(cfg.mlflow.experiment_name)
return cls( return cls(
mlflow=mlflow,
tracking_uri=tracking_uri, tracking_uri=tracking_uri,
experiment_name=cfg.mlflow.experiment_name, experiment_name=cfg.mlflow.experiment_name,
registered_model_name=cfg.mlflow.registered_model_name, registered_model_name=cfg.mlflow.registered_model_name,
@@ -72,7 +65,7 @@ class MlflowTracker:
) )
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
run = self.mlflow.start_run(run_name=training_job.job_name) run = mlflow.start_run(run_name=training_job.job_name)
run_id = str(run.info.run_id) run_id = str(run.info.run_id)
params = { params = {
@@ -90,21 +83,21 @@ class MlflowTracker:
} }
self._log_params(params) self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
self.mlflow.set_tags( mlflow.set_tags(
{ {
"qc_cli.stage": "prerelease", "qc_cli.stage": "prerelease",
"qc_cli.command": "train start", "qc_cli.command": "train start",
"sagemaker.job_name": training_job.job_name, "sagemaker.job_name": training_job.job_name,
} }
) )
self.mlflow.end_run() mlflow.end_run()
return run_id return run_id
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
if not run_id: if not run_id:
return None return None
with self.mlflow.start_run(run_id=run_id): with mlflow.start_run(run_id=run_id):
self._log_params( self._log_params(
{ {
"sagemaker.training_status": training_job_status.status, "sagemaker.training_status": training_job_status.status,
@@ -115,16 +108,16 @@ class MlflowTracker:
} }
) )
self._log_final_metrics(training_job_status.raw) self._log_final_metrics(training_job_status.raw)
self.mlflow.set_tag("qc_cli.command", "train status") mlflow.set_tag("qc_cli.command", "train status")
if training_job_status.status != "Completed" or not training_job_status.model_artifacts: if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
self.mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
return None return None
if not self.register_trained_models: if not self.register_trained_models:
return None return None
client = self.mlflow.tracking.MlflowClient() client = MlflowClient()
self._ensure_registered_model(client, self.registered_model_name) self._ensure_registered_model(client, self.registered_model_name)
version = client.create_model_version( version = client.create_model_version(
name=self.registered_model_name, name=self.registered_model_name,
@@ -137,14 +130,14 @@ class MlflowTracker:
) )
version_number = str(version.version) version_number = str(version.version)
self._set_alias(client, self.registered_model_name, "prerelease-latest", version_number) self._set_alias(client, self.registered_model_name, "prerelease-latest", version_number)
self.mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
self.mlflow.set_tag("qc_cli.registered_model_version", version_number) mlflow.set_tag("qc_cli.registered_model_version", version_number)
return version_number return version_number
def _log_params(self, params: dict[str, Any]) -> None: def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None} cleaned = {key: str(value) for key, value in params.items() if value is not None}
if cleaned: if cleaned:
self.mlflow.log_params(cleaned) mlflow.log_params(cleaned)
def _log_final_metrics(self, training_job: dict[str, Any]) -> None: def _log_final_metrics(self, training_job: dict[str, Any]) -> None:
metrics = {} metrics = {}
@@ -154,14 +147,13 @@ class MlflowTracker:
if name and value is not None: if name and value is not None:
metrics[str(name)] = float(value) metrics[str(name)] = float(value)
if metrics: if metrics:
self.mlflow.log_metrics(metrics) mlflow.log_metrics(metrics)
def _ensure_registered_model(self, client: Any, name: str) -> None: def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try: try:
client.get_registered_model(name) client.get_registered_model(name)
except Exception: except Exception:
client.create_registered_model(name) client.create_registered_model(name)
def _set_alias(self, client: Any, name: str, alias: str, version: str) -> None: def _set_alias(self, client: MlflowClient, name: str, alias: str, version: str) -> None:
if hasattr(client, "set_registered_model_alias"): client.set_registered_model_alias(name, alias, version)
client.set_registered_model_alias(name, alias, version)

13
uv.lock generated
View File

@@ -2011,15 +2011,11 @@ dependencies = [
{ name = "aws-cdk-lib" }, { name = "aws-cdk-lib" },
{ name = "boto3" }, { name = "boto3" },
{ name = "constructs" }, { name = "constructs" },
{ name = "mlflow" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "pyyaml" }, { name = "pyyaml" },
{ name = "typer" },
]
[package.optional-dependencies]
mlflow = [
{ name = "mlflow" },
{ name = "sagemaker-mlflow" }, { name = "sagemaker-mlflow" },
{ name = "typer" },
] ]
[package.dev-dependencies] [package.dev-dependencies]
@@ -2036,13 +2032,12 @@ requires-dist = [
{ name = "aws-cdk-lib", specifier = ">=2.180.0" }, { name = "aws-cdk-lib", specifier = ">=2.180.0" },
{ name = "boto3", specifier = ">=1.34,<1.42" }, { name = "boto3", specifier = ">=1.34,<1.42" },
{ name = "constructs", specifier = ">=10.0.0" }, { name = "constructs", specifier = ">=10.0.0" },
{ name = "mlflow", marker = "extra == 'mlflow'", specifier = ">=3.0" }, { name = "mlflow", specifier = ">=3.0" },
{ name = "pydantic", specifier = ">=2.13.3" }, { name = "pydantic", specifier = ">=2.13.3" },
{ name = "pyyaml", specifier = ">=6.0.3" }, { name = "pyyaml", specifier = ">=6.0.3" },
{ name = "sagemaker-mlflow", marker = "extra == 'mlflow'", specifier = ">=0.4.0" }, { name = "sagemaker-mlflow", specifier = ">=0.4.0" },
{ name = "typer", specifier = "==0.25.0" }, { name = "typer", specifier = "==0.25.0" },
] ]
provides-extras = ["mlflow"]
[package.metadata.requires-dev] [package.metadata.requires-dev]
dev = [ dev = [