Mlflow implementation #2
@@ -78,12 +78,13 @@ To provision an MLflow tracking server, set:
|
||||
```yaml
|
||||
mlflow:
|
||||
mode: create
|
||||
tracking_server_name: your-tracking-server-name
|
||||
experiment_name: qc-cli-training
|
||||
registered_model_name: qc-cli-model
|
||||
register_trained_models: true
|
||||
```
|
||||
|
||||
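In `create` mode, the CLI derives the tracking server name from `infra.stack_name` (as `<stack_name>-mlflow`), so `mlflow.tracking_server_name` does not need to be set; it is only required in `existing` mode.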
|
||||
|
||||
To use an existing MLflow tracking server, set:
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -77,7 +77,8 @@ def setup(
|
||||
if outputs.get("SageMakerRoleArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
|
||||
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}")
|
||||
mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
|
||||
elif cfg.mlflow.mode is MlflowMode.existing:
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
|
||||
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
|
||||
@@ -102,7 +103,7 @@ def status(config: str = CONFIG_OPT) -> None:
|
||||
if cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
cfg.effective_mlflow_tracking_server_name or "-",
|
||||
"[red]unknown[/red]",
|
||||
"-",
|
||||
)
|
||||
@@ -126,7 +127,7 @@ def status(config: str = CONFIG_OPT) -> None:
|
||||
if cfg.mlflow.mode is MlflowMode.create:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
|
||||
"[green]managed[/green]",
|
||||
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||
@@ -7,13 +7,13 @@ from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class MlflowMode(str, Enum):
|
||||
class MlflowMode(StrEnum):
|
||||
disabled = "disabled"
|
||||
create = "create"
|
||||
existing = "existing"
|
||||
|
||||
|
||||
class MlflowServerSize(str, Enum):
|
||||
class MlflowServerSize(StrEnum):
|
||||
small = "Small"
|
||||
medium = "Medium"
|
||||
large = "Large"
|
||||
@@ -94,8 +94,8 @@ class MlflowConfig(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_tracking_server_name(self) -> "MlflowConfig":
|
||||
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
|
||||
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
|
||||
if self.mode is MlflowMode.existing and not self.tracking_server_name:
|
||||
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is existing")
|
||||
return self
|
||||
|
||||
|
||||
@@ -105,3 +105,15 @@ class Config(BaseModel):
|
||||
s3: S3Config = Field(default_factory=S3Config)
|
||||
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
|
||||
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
|
||||
|
||||
@property
|
||||
def managed_mlflow_tracking_server_name(self) -> str:
|
||||
return f"{self.infra.stack_name}-mlflow"
|
||||
|
||||
@property
|
||||
def effective_mlflow_tracking_server_name(self) -> str | None:
|
||||
if self.mlflow.mode is MlflowMode.disabled:
|
||||
return None
|
||||
if self.mlflow.mode is MlflowMode.existing:
|
||||
return self.mlflow.tracking_server_name
|
||||
return self.managed_mlflow_tracking_server_name
|
||||
|
||||
@@ -74,6 +74,7 @@ class QCStack(Stack):
|
||||
CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
|
||||
|
||||
if config.mlflow.mode is MlflowMode.create:
|
||||
tracking_server_name = config.managed_mlflow_tracking_server_name
|
||||
artifact_prefix = config.mlflow.artifact_prefix.strip("/")
|
||||
artifact_uri = (
|
||||
f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
|
||||
@@ -145,14 +146,14 @@ class QCStack(Stack):
|
||||
"MlflowTrackingServer",
|
||||
artifact_store_uri=artifact_uri,
|
||||
role_arn=mlflow_role.attr_arn,
|
||||
tracking_server_name=config.mlflow.tracking_server_name or "",
|
||||
tracking_server_name=tracking_server_name,
|
||||
automatic_model_registration=config.mlflow.automatic_model_registration,
|
||||
mlflow_version=config.mlflow.mlflow_version,
|
||||
tracking_server_size=config.mlflow.tracking_server_size.value,
|
||||
weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
|
||||
)
|
||||
|
||||
CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "")
|
||||
CfnOutput(self, "MlflowTrackingServerName", value=tracking_server_name)
|
||||
CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
|
||||
CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
|
||||
CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)
|
||||
|
||||
@@ -48,13 +48,14 @@ class MlflowTracker:
|
||||
"Install with: qc-cli[mlflow]"
|
||||
) from e
|
||||
|
||||
if not cfg.mlflow.tracking_server_name:
|
||||
raise RuntimeError("mlflow.tracking_server_name is required when MLflow is enabled.")
|
||||
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
||||
if not tracking_server_name:
|
||||
raise RuntimeError("MLflow tracking server name could not be resolved.")
|
||||
|
||||
tracking_uri = aws_mlflow.get_tracking_server_arn(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
cfg.mlflow.tracking_server_name,
|
||||
tracking_server_name,
|
||||
)
|
||||
mlflow.set_tracking_uri(tracking_uri)
|
||||
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
||||
|
||||
Reference in New Issue
Block a user