diff --git a/README.md b/README.md index 3ad7811..5ef0cec 100644 --- a/README.md +++ b/README.md @@ -78,12 +78,13 @@ To provision an MLflow tracking server, set: ```yaml mlflow: mode: create - tracking_server_name: your-tracking-server-name experiment_name: qc-cli-training registered_model_name: qc-cli-model register_trained_models: true ``` +In `create` mode, the CLI manages the tracking server name from `infra.stack_name`. + To use an existing MLflow tracking server, set: ```yaml diff --git a/src/commands/infra.py b/src/commands/infra.py index aa42e9c..c4f2aee 100644 --- a/src/commands/infra.py +++ b/src/commands/infra.py @@ -77,7 +77,8 @@ def setup( if outputs.get("SageMakerRoleArn"): CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}") if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"): - CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}") + mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name) + CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}") elif cfg.mlflow.mode is MlflowMode.existing: CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}") CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]") @@ -102,7 +103,7 @@ def status(config: str = CONFIG_OPT) -> None: if cfg.mlflow.mode is not MlflowMode.disabled: table.add_row( "MLflow", - cfg.mlflow.tracking_server_name or "-", + cfg.effective_mlflow_tracking_server_name or "-", "[red]unknown[/red]", "-", ) @@ -126,7 +127,7 @@ def status(config: str = CONFIG_OPT) -> None: if cfg.mlflow.mode is MlflowMode.create: table.add_row( "MLflow", - cfg.mlflow.tracking_server_name or "-", + outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name), "[green]managed[/green]", outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")), ) diff --git a/src/config.py b/src/config.py index 184ac0a..9212b59 100644 --- a/src/config.py +++ b/src/config.py @@ -1,5 +1,5 @@ import re -from enum import Enum +from enum import StrEnum from typing import Any, Literal, TypedDict from mypy_boto3_s3.literals import BucketLocationConstraintType @@ -7,13 +7,13 @@ from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType from pydantic import BaseModel, Field, model_validator -class MlflowMode(str, Enum): +class MlflowMode(StrEnum): disabled = "disabled" create = "create" existing = "existing" -class MlflowServerSize(str, Enum): +class MlflowServerSize(StrEnum): small = "Small" medium = "Medium" large = "Large" @@ -94,8 +94,8 @@ class MlflowConfig(BaseModel): @model_validator(mode="after") def require_tracking_server_name(self) -> "MlflowConfig": - if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name: - raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing") + if self.mode is MlflowMode.existing and not self.tracking_server_name: + raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is existing") return self @@ -105,3 +105,15 @@ class Config(BaseModel): s3: S3Config = Field(default_factory=S3Config) sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig) mlflow: MlflowConfig = Field(default_factory=MlflowConfig) + + @property + def managed_mlflow_tracking_server_name(self) -> str: + return f"{self.infra.stack_name}-mlflow" + + @property + def effective_mlflow_tracking_server_name(self) -> str | None: + if self.mlflow.mode is MlflowMode.disabled: + return None + if self.mlflow.mode is MlflowMode.existing: + return self.mlflow.tracking_server_name + return self.managed_mlflow_tracking_server_name diff --git a/src/infra/stack.py b/src/infra/stack.py index 0939b66..ed1119f 100644 --- a/src/infra/stack.py +++ b/src/infra/stack.py @@ -74,6 +74,7 @@ class QCStack(Stack): CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn) if config.mlflow.mode is MlflowMode.create: + tracking_server_name = config.managed_mlflow_tracking_server_name artifact_prefix = config.mlflow.artifact_prefix.strip("/") artifact_uri = ( f"s3://{data_bucket.bucket_name}/{artifact_prefix}/" @@ -145,14 +146,14 @@ class QCStack(Stack): "MlflowTrackingServer", artifact_store_uri=artifact_uri, role_arn=mlflow_role.attr_arn, - tracking_server_name=config.mlflow.tracking_server_name or "", + tracking_server_name=tracking_server_name, automatic_model_registration=config.mlflow.automatic_model_registration, mlflow_version=config.mlflow.mlflow_version, tracking_server_size=config.mlflow.tracking_server_size.value, weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start, ) - CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "") + CfnOutput(self, "MlflowTrackingServerName", value=tracking_server_name) CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn) CfnOutput(self, "MlflowArtifactUri", value=artifact_uri) CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn) diff --git a/src/tracking/mlflow.py b/src/tracking/mlflow.py index adba5e0..ac8dd87 100644 --- a/src/tracking/mlflow.py +++ b/src/tracking/mlflow.py @@ -48,13 +48,14 @@ class MlflowTracker: "Install with: qc-cli[mlflow]" ) from e - if not cfg.mlflow.tracking_server_name: - raise RuntimeError("mlflow.tracking_server_name is required when MLflow is enabled.") + tracking_server_name = cfg.effective_mlflow_tracking_server_name + if not tracking_server_name: + raise RuntimeError("MLflow tracking server name could not be resolved.") tracking_uri = aws_mlflow.get_tracking_server_arn( cfg.aws.region, cfg.aws.profile, - cfg.mlflow.tracking_server_name, + tracking_server_name, ) mlflow.set_tracking_uri(tracking_uri) mlflow.set_experiment(cfg.mlflow.experiment_name)