omit server name when created with config

This commit is contained in:
2026-05-27 10:23:53 -04:00
parent 35d25d8967
commit e1c8d6574f
5 changed files with 30 additions and 14 deletions

View File

@@ -78,12 +78,13 @@ To provision an MLflow tracking server, set:
```yaml ```yaml
mlflow: mlflow:
mode: create mode: create
tracking_server_name: your-tracking-server-name
experiment_name: qc-cli-training experiment_name: qc-cli-training
registered_model_name: qc-cli-model registered_model_name: qc-cli-model
register_trained_models: true register_trained_models: true
``` ```
In `create` mode, the CLI manages the tracking server name from `infra.stack_name`.
To use an existing MLflow tracking server, set: To use an existing MLflow tracking server, set:
```yaml ```yaml

View File

@@ -77,7 +77,8 @@ def setup(
if outputs.get("SageMakerRoleArn"): if outputs.get("SageMakerRoleArn"):
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}") CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"): if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}") mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
elif cfg.mlflow.mode is MlflowMode.existing: elif cfg.mlflow.mode is MlflowMode.existing:
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}") CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]") CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
@@ -102,7 +103,7 @@ def status(config: str = CONFIG_OPT) -> None:
if cfg.mlflow.mode is not MlflowMode.disabled: if cfg.mlflow.mode is not MlflowMode.disabled:
table.add_row( table.add_row(
"MLflow", "MLflow",
cfg.mlflow.tracking_server_name or "-", cfg.effective_mlflow_tracking_server_name or "-",
"[red]unknown[/red]", "[red]unknown[/red]",
"-", "-",
) )
@@ -126,7 +127,7 @@ def status(config: str = CONFIG_OPT) -> None:
if cfg.mlflow.mode is MlflowMode.create: if cfg.mlflow.mode is MlflowMode.create:
table.add_row( table.add_row(
"MLflow", "MLflow",
cfg.mlflow.tracking_server_name or "-", outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
"[green]managed[/green]", "[green]managed[/green]",
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")), outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
) )

View File

@@ -1,5 +1,5 @@
import re import re
from enum import Enum from enum import StrEnum
from typing import Any, Literal, TypedDict from typing import Any, Literal, TypedDict
from mypy_boto3_s3.literals import BucketLocationConstraintType from mypy_boto3_s3.literals import BucketLocationConstraintType
@@ -7,13 +7,13 @@ from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
class MlflowMode(str, Enum): class MlflowMode(StrEnum):
disabled = "disabled" disabled = "disabled"
create = "create" create = "create"
existing = "existing" existing = "existing"
class MlflowServerSize(str, Enum): class MlflowServerSize(StrEnum):
small = "Small" small = "Small"
medium = "Medium" medium = "Medium"
large = "Large" large = "Large"
@@ -94,8 +94,8 @@ class MlflowConfig(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def require_tracking_server_name(self) -> "MlflowConfig": def require_tracking_server_name(self) -> "MlflowConfig":
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name: if self.mode is MlflowMode.existing and not self.tracking_server_name:
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing") raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is existing")
return self return self
@@ -105,3 +105,15 @@ class Config(BaseModel):
s3: S3Config = Field(default_factory=S3Config) s3: S3Config = Field(default_factory=S3Config)
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig) sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
mlflow: MlflowConfig = Field(default_factory=MlflowConfig) mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
@property
def managed_mlflow_tracking_server_name(self) -> str:
return f"{self.infra.stack_name}-mlflow"
@property
def effective_mlflow_tracking_server_name(self) -> str | None:
if self.mlflow.mode is MlflowMode.disabled:
return None
if self.mlflow.mode is MlflowMode.existing:
return self.mlflow.tracking_server_name
return self.managed_mlflow_tracking_server_name

View File

@@ -74,6 +74,7 @@ class QCStack(Stack):
CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn) CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
if config.mlflow.mode is MlflowMode.create: if config.mlflow.mode is MlflowMode.create:
tracking_server_name = config.managed_mlflow_tracking_server_name
artifact_prefix = config.mlflow.artifact_prefix.strip("/") artifact_prefix = config.mlflow.artifact_prefix.strip("/")
artifact_uri = ( artifact_uri = (
f"s3://{data_bucket.bucket_name}/{artifact_prefix}/" f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
@@ -145,14 +146,14 @@ class QCStack(Stack):
"MlflowTrackingServer", "MlflowTrackingServer",
artifact_store_uri=artifact_uri, artifact_store_uri=artifact_uri,
role_arn=mlflow_role.attr_arn, role_arn=mlflow_role.attr_arn,
tracking_server_name=config.mlflow.tracking_server_name or "", tracking_server_name=tracking_server_name,
automatic_model_registration=config.mlflow.automatic_model_registration, automatic_model_registration=config.mlflow.automatic_model_registration,
mlflow_version=config.mlflow.mlflow_version, mlflow_version=config.mlflow.mlflow_version,
tracking_server_size=config.mlflow.tracking_server_size.value, tracking_server_size=config.mlflow.tracking_server_size.value,
weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start, weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
) )
CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "") CfnOutput(self, "MlflowTrackingServerName", value=tracking_server_name)
CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn) CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
CfnOutput(self, "MlflowArtifactUri", value=artifact_uri) CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn) CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)

View File

@@ -48,13 +48,14 @@ class MlflowTracker:
"Install with: qc-cli[mlflow]" "Install with: qc-cli[mlflow]"
) from e ) from e
if not cfg.mlflow.tracking_server_name: tracking_server_name = cfg.effective_mlflow_tracking_server_name
raise RuntimeError("mlflow.tracking_server_name is required when MLflow is enabled.") if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_uri = aws_mlflow.get_tracking_server_arn( tracking_uri = aws_mlflow.get_tracking_server_arn(
cfg.aws.region, cfg.aws.region,
cfg.aws.profile, cfg.aws.profile,
cfg.mlflow.tracking_server_name, tracking_server_name,
) )
mlflow.set_tracking_uri(tracking_uri) mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name) mlflow.set_experiment(cfg.mlflow.experiment_name)