Mlflow implementation #2

Merged
slalom merged 10 commits from ml-flow into main 2026-06-02 19:04:23 +00:00
5 changed files with 30 additions and 14 deletions
Showing only changes of commit e1c8d6574f - Show all commits

View File

@@ -78,12 +78,13 @@ To provision an MLflow tracking server, set:
```yaml
mlflow:
mode: create
tracking_server_name: your-tracking-server-name
experiment_name: qc-cli-training
registered_model_name: qc-cli-model
register_trained_models: true
```
In `create` mode, the CLI manages the tracking server name from `infra.stack_name`.
To use an existing MLflow tracking server, set:
```yaml

View File

@@ -77,7 +77,8 @@ def setup(
if outputs.get("SageMakerRoleArn"):
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}")
mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
elif cfg.mlflow.mode is MlflowMode.existing:
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
@@ -102,7 +103,7 @@ def status(config: str = CONFIG_OPT) -> None:
if cfg.mlflow.mode is not MlflowMode.disabled:
table.add_row(
"MLflow",
cfg.mlflow.tracking_server_name or "-",
cfg.effective_mlflow_tracking_server_name or "-",
"[red]unknown[/red]",
"-",
)
@@ -126,7 +127,7 @@ def status(config: str = CONFIG_OPT) -> None:
if cfg.mlflow.mode is MlflowMode.create:
table.add_row(
"MLflow",
cfg.mlflow.tracking_server_name or "-",
outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
"[green]managed[/green]",
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
)

View File

@@ -1,5 +1,5 @@
import re
from enum import Enum
from enum import StrEnum
from typing import Any, Literal, TypedDict
from mypy_boto3_s3.literals import BucketLocationConstraintType
@@ -7,13 +7,13 @@ from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, model_validator
class MlflowMode(str, Enum):
class MlflowMode(StrEnum):
disabled = "disabled"
create = "create"
existing = "existing"
class MlflowServerSize(str, Enum):
class MlflowServerSize(StrEnum):
small = "Small"
medium = "Medium"
large = "Large"
@@ -94,8 +94,8 @@ class MlflowConfig(BaseModel):
@model_validator(mode="after")
def require_tracking_server_name(self) -> "MlflowConfig":
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
if self.mode is MlflowMode.existing and not self.tracking_server_name:
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is existing")
return self
@@ -105,3 +105,15 @@ class Config(BaseModel):
s3: S3Config = Field(default_factory=S3Config)
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
@property
def managed_mlflow_tracking_server_name(self) -> str:
return f"{self.infra.stack_name}-mlflow"
@property
def effective_mlflow_tracking_server_name(self) -> str | None:
if self.mlflow.mode is MlflowMode.disabled:
return None
if self.mlflow.mode is MlflowMode.existing:
return self.mlflow.tracking_server_name
return self.managed_mlflow_tracking_server_name

View File

@@ -74,6 +74,7 @@ class QCStack(Stack):
CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
if config.mlflow.mode is MlflowMode.create:
tracking_server_name = config.managed_mlflow_tracking_server_name
artifact_prefix = config.mlflow.artifact_prefix.strip("/")
artifact_uri = (
f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
@@ -145,14 +146,14 @@ class QCStack(Stack):
"MlflowTrackingServer",
artifact_store_uri=artifact_uri,
role_arn=mlflow_role.attr_arn,
tracking_server_name=config.mlflow.tracking_server_name or "",
tracking_server_name=tracking_server_name,
automatic_model_registration=config.mlflow.automatic_model_registration,
mlflow_version=config.mlflow.mlflow_version,
tracking_server_size=config.mlflow.tracking_server_size.value,
weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
)
CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "")
CfnOutput(self, "MlflowTrackingServerName", value=tracking_server_name)
CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)

View File

@@ -48,13 +48,14 @@ class MlflowTracker:
"Install with: qc-cli[mlflow]"
) from e
if not cfg.mlflow.tracking_server_name:
raise RuntimeError("mlflow.tracking_server_name is required when MLflow is enabled.")
tracking_server_name = cfg.effective_mlflow_tracking_server_name
if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_uri = aws_mlflow.get_tracking_server_arn(
cfg.aws.region,
cfg.aws.profile,
cfg.mlflow.tracking_server_name,
tracking_server_name,
)
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name)