108 lines
3.4 KiB
Python
108 lines
3.4 KiB
Python
import re
|
|
from enum import Enum
|
|
from typing import Any, Literal, TypedDict
|
|
|
|
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
|
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
|
from pydantic import BaseModel, Field, model_validator
|
|
|
|
|
|
class MlflowMode(str, Enum):
    """How MLflow is used, as a string-valued enumeration.

    Every member's value equals its name, and because the class mixes in
    ``str`` a member compares equal to that plain string.
    """

    # MLflow integration is switched off.
    disabled = "disabled"
    # NOTE(review): presumably "provision a new tracking server" -- confirm
    # against the code that consumes this mode.
    create = "create"
    # NOTE(review): presumably "reuse a tracking server that already
    # exists" -- confirm against the code that consumes this mode.
    existing = "existing"
|
|
|
|
|
|
class MlflowServerSize(str, Enum):
    """Size choices for an MLflow tracking server.

    Member names are lower-case while the values are capitalised.
    NOTE(review): the capitalised values presumably mirror the strings the
    SageMaker API expects for a tracking-server size -- confirm.
    """

    small = "Small"
    medium = "Medium"
    large = "Large"
|
|
|
|
|
|
class Boto3SessionKwargs(TypedDict):
    """Typed shape of the mapping returned by ``AwsConfig.boto3_session``.

    Both keys are required.
    NOTE(review): the key names look like ``boto3.Session`` keyword
    arguments -- confirm at the call site that unpacks this mapping.
    """

    # Name of the AWS credentials profile.
    profile_name: str
    # Name of the AWS region.
    region_name: str
|
|
|
|
|
|
class AwsConfig(BaseModel):
    """AWS account settings: which region and which credentials profile."""

    # The imported S3 location-constraint literal is widened explicitly with
    # "us-east-1", which is also the default.
    region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
    # Name of the credentials profile to use.
    profile: str = "default"

    @property
    def boto3_session(self) -> Boto3SessionKwargs:
        """Return the profile and region as a ``Boto3SessionKwargs`` mapping.

        Despite the name, this yields a plain keyword mapping, not a session
        object.
        """
        return Boto3SessionKwargs(
            profile_name=self.profile,
            region_name=self.region,
        )
|
|
|
|
|
|
# Fallback returned by InfraConfig.effective_bootstrap_qualifier when the
# stack name has no characters in [a-z0-9] left after lower-casing.
# NOTE(review): this looks like the AWS CDK's default bootstrap qualifier --
# confirm against the CDK bootstrapping documentation.
DEFAULT_BOOTSTRAP_QUALIFIER = "hnb659fds"
|
|
# Stack-name prefix that InfraConfig strips before deriving the bootstrap
# qualifier.
# NOTE(review): presumably the prefix given to auto-generated stack names --
# confirm where such names are produced.
GENERATED_STACK_PREFIX = "qc-cli-mlops-"
|
|
|
|
|
|
def _lower_alnum(text: str) -> str:
    """Lower-case *text* and drop every character outside ``[a-z0-9]``."""
    return re.sub(r"[^a-z0-9]", "", text.lower())


class InfraConfig(BaseModel):
    """Infrastructure settings, keyed off the deployment stack name."""

    # Name of the stack; the derived properties below are computed from it.
    stack_name: str

    @property
    def effective_bootstrap_qualifier(self) -> str:
        """Derive a bootstrap qualifier from ``stack_name``.

        Returns:
            ``DEFAULT_BOOTSTRAP_QUALIFIER`` when the stack name contains no
            ``[a-z0-9]`` characters after lower-casing. Otherwise ``"q"``
            followed by the sanitised name, truncated to 10 characters. When
            the name starts with ``GENERATED_STACK_PREFIX`` and something
            usable follows it, only that remainder is used, so the fixed
            prefix does not consume the whole qualifier.

        NOTE(review): the 10-character cap presumably matches the CDK
        bootstrap qualifier length limit -- confirm.
        """
        sanitized = _lower_alnum(self.stack_name)
        if not sanitized:
            return DEFAULT_BOOTSTRAP_QUALIFIER

        # ``removeprefix`` is a no-op when the prefix is absent, so for
        # non-generated names ``suffix`` equals ``sanitized``. For a generated
        # name with nothing usable after the prefix, fall back to the whole
        # sanitised name.
        suffix = _lower_alnum(self.stack_name.removeprefix(GENERATED_STACK_PREFIX))
        return f"q{suffix or sanitized}"[:10]

    @property
    def effective_toolkit_stack_name(self) -> str:
        """Return the toolkit stack name: ``stack_name`` plus ``-bootstrap``.

        The result is the same whether or not the name starts with
        ``GENERATED_STACK_PREFIX``.
        """
        return f"{self.stack_name}-bootstrap"
|
|
|
|
|
|
class S3Config(BaseModel):
    """S3 settings: the bucket name and two prefix strings."""

    # Bucket name; the default is a placeholder-looking value.
    bucket: str = "my-qc-mlops-bucket"
    # NOTE(review): presumably the key prefix for datasets -- confirm.
    data_prefix: str = "data/"
    # NOTE(review): presumably the key prefix for model artifacts -- confirm.
    model_prefix: str = "models/"
|
|
|
|
|
|
class TrainingConfig(BaseModel):
    """Settings for a training job."""

    # Instance type, constrained to the imported SageMaker literal type.
    instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
    # Number of instances; defaults to one.
    instance_count: int = 1
    # Container image URI; defaults to the empty string.
    # NOTE(review): how an empty value is treated is decided by the
    # consumer -- confirm.
    image_uri: str = ""
    # Optional entry point and source directory; both default to None.
    entry_point: str | None = None
    source_dir: str | None = None
    # Free-form hyperparameters; every instance gets its own fresh dict.
    hyperparameters: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class SageMakerConfig(BaseModel):
    """SageMaker settings: a role name plus nested training settings."""

    # Role name; defaults to the empty string.
    # NOTE(review): presumably an IAM role name -- confirm with the consumer.
    role_name: str = ""
    # Nested training settings; defaults to a TrainingConfig built from its
    # own defaults.
    training: TrainingConfig = Field(default_factory=TrainingConfig)
|
|
|
|
|
|
class MlflowConfig(BaseModel):
    """MLflow settings.

    A model-level validator rejects configurations whose ``mode`` is
    ``create`` or ``existing`` but which carry no ``tracking_server_name``.
    """

    # Integration mode; disabled unless configured otherwise.
    mode: MlflowMode = MlflowMode.disabled
    # Tracking server name; mandatory for the create and existing modes.
    tracking_server_name: str | None = None
    experiment_name: str = "qc-cli-training"
    registered_model_name: str = "qc-cli-model"
    register_trained_models: bool = True
    # NOTE(review): presumably a storage key prefix for MLflow artifacts --
    # confirm with the consumer.
    artifact_prefix: str = "mlflow/"
    tracking_server_size: MlflowServerSize = MlflowServerSize.small
    # None leaves the MLflow version unspecified.
    mlflow_version: str | None = None
    automatic_model_registration: bool = False
    # NOTE(review): the expected string format is not visible here -- confirm.
    weekly_maintenance_window_start: str | None = None

    @model_validator(mode="after")
    def require_tracking_server_name(self) -> "MlflowConfig":
        """Ensure a tracking server name accompanies the create/existing modes.

        Returns:
            The validated model itself.

        Raises:
            ValueError: If ``mode`` is ``create`` or ``existing`` and
                ``tracking_server_name`` is missing or empty.
        """
        server_required = self.mode in {MlflowMode.create, MlflowMode.existing}
        if server_required and not self.tracking_server_name:
            raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
        return self
|
|
|
|
|
|
class Config(BaseModel):
    """Top-level configuration that aggregates every section.

    Only ``infra`` must be supplied; each other section falls back to an
    instance of its model built from that model's defaults.
    """

    # Required: there is no default because InfraConfig needs a stack name.
    infra: InfraConfig
    aws: AwsConfig = Field(default_factory=AwsConfig)
    s3: S3Config = Field(default_factory=S3Config)
    sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
    mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
|