import re
from enum import Enum
from typing import Any, Literal, TypedDict

from mypy_boto3_s3.literals import BucketLocationConstraintType
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, model_validator


class MlflowMode(str, Enum):
    """How MLflow tracking is configured: off, newly created, or pre-existing."""

    disabled = "disabled"
    create = "create"
    existing = "existing"


class MlflowServerSize(str, Enum):
    """Allowed tracking-server size values."""

    small = "Small"
    medium = "Medium"
    large = "Large"


class Boto3SessionKwargs(TypedDict):
    """Keyword arguments for constructing a ``boto3.Session``."""

    profile_name: str
    region_name: str


class AwsConfig(BaseModel):
    """AWS region and credentials profile."""

    # "us-east-1" is listed separately because it is not a member of
    # BucketLocationConstraintType.
    region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
    profile: str = "default"

    @property
    def boto3_session(self) -> Boto3SessionKwargs:
        """Return ``profile_name``/``region_name`` kwargs, e.g. for ``boto3.Session(**kwargs)``."""
        return {"profile_name": self.profile, "region_name": self.region}


DEFAULT_BOOTSTRAP_QUALIFIER = "hnb659fds"
GENERATED_STACK_PREFIX = "qc-cli-mlops-"

# Derived qualifiers are truncated to this many characters.
_MAX_BOOTSTRAP_QUALIFIER_LENGTH = 10
# Anything outside lowercase ASCII letters and digits is dropped from qualifiers.
_QUALIFIER_DISALLOWED_CHARS = re.compile(r"[^a-z0-9]")


def _qualifier_chars(text: str) -> str:
    """Lowercase *text* and keep only ``a-z0-9`` characters."""
    return _QUALIFIER_DISALLOWED_CHARS.sub("", text.lower())


class InfraConfig(BaseModel):
    """Infrastructure (stack) settings."""

    stack_name: str

    @property
    def effective_bootstrap_qualifier(self) -> str:
        """Derive a bootstrap qualifier from ``stack_name``.

        The name is lowercased and stripped to ``a-z0-9``, prefixed with ``"q"``
        and truncated to 10 characters. For names starting with
        ``GENERATED_STACK_PREFIX``, only the part after the prefix is used,
        provided it contains any usable characters. If nothing usable remains
        at all, ``DEFAULT_BOOTSTRAP_QUALIFIER`` is returned.
        """
        sanitized = _qualifier_chars(self.stack_name)
        if not sanitized:
            return DEFAULT_BOOTSTRAP_QUALIFIER
        if self.stack_name.startswith(GENERATED_STACK_PREFIX):
            # Drop the shared generated prefix so the distinguishing suffix
            # survives truncation.
            suffix = _qualifier_chars(self.stack_name.removeprefix(GENERATED_STACK_PREFIX))
            if suffix:
                sanitized = suffix
        # NOTE(review): presumably the leading "q" guarantees the qualifier
        # starts with a letter. Truncation means stack names sharing their
        # first 9 usable characters map to the same qualifier.
        return f"q{sanitized}"[:_MAX_BOOTSTRAP_QUALIFIER_LENGTH]

    @property
    def effective_toolkit_stack_name(self) -> str:
        """Name of the bootstrap (toolkit) stack paired with ``stack_name``.

        Always ``"<stack_name>-bootstrap"``.
        """
        # NOTE(review): generated and user-supplied stack names are treated the
        # same here. If generated names were meant to get a distinct toolkit
        # name, that logic does not exist.
        return f"{self.stack_name}-bootstrap"


class S3Config(BaseModel):
    """S3 bucket and key prefixes for data and model artifacts."""

    bucket: str = "my-qc-mlops-bucket"
    data_prefix: str = "data/"
    model_prefix: str = "models/"


class TrainingConfig(BaseModel):
    """SageMaker training-job settings."""

    instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
    instance_count: int = 1
    image_uri: str = ""
    entry_point: str | None = None
    source_dir: str | None = None
    # default_factory gives each instance its own dict.
    hyperparameters: dict[str, Any] = Field(default_factory=dict)


class SageMakerConfig(BaseModel):
    """SageMaker role and training configuration."""

    role_name: str = ""
    training: TrainingConfig = Field(default_factory=TrainingConfig)


class MlflowConfig(BaseModel):
    """MLflow tracking and model-registry settings."""

    mode: MlflowMode = MlflowMode.disabled
    # Required when mode is ``create`` or ``existing`` (see validator below).
    tracking_server_name: str | None = None
    experiment_name: str = "qc-cli-training"
    registered_model_name: str = "qc-cli-model"
    register_trained_models: bool = True
    artifact_prefix: str = "mlflow/"
    tracking_server_size: MlflowServerSize = MlflowServerSize.small
    mlflow_version: str | None = None
    automatic_model_registration: bool = False
    weekly_maintenance_window_start: str | None = None

    @model_validator(mode="after")
    def require_tracking_server_name(self) -> "MlflowConfig":
        """Reject a missing/empty ``tracking_server_name`` unless MLflow is disabled.

        Raises:
            ValueError: if ``mode`` is ``create`` or ``existing`` and
                ``tracking_server_name`` is ``None`` or empty.
        """
        if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
            raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
        return self


class Config(BaseModel):
    """Top-level configuration; only ``infra`` has no default."""

    infra: InfraConfig
    aws: AwsConfig = Field(default_factory=AwsConfig)
    s3: S3Config = Field(default_factory=S3Config)
    sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
    mlflow: MlflowConfig = Field(default_factory=MlflowConfig)