Files
qai-cli/src/config.py

108 lines
3.4 KiB
Python

import re
from enum import Enum
from typing import Any, Literal, TypedDict
from mypy_boto3_s3.literals import BucketLocationConstraintType
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, model_validator
class MlflowMode(str, Enum):
    """How MLflow is configured; see ``MlflowConfig.mode``.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values.
    """

    # Default mode of MlflowConfig.
    disabled = "disabled"
    # ``create`` and ``existing`` both require
    # MlflowConfig.tracking_server_name (enforced by its validator).
    create = "create"
    existing = "existing"
class MlflowServerSize(str, Enum):
    """Size option for an MLflow tracking server (``MlflowConfig.tracking_server_size``).

    Values are capitalized strings; presumably they must match the spelling
    the tracking-server API expects — confirm before changing them.
    """

    small = "Small"  # default in MlflowConfig
    medium = "Medium"
    large = "Large"
class Boto3SessionKwargs(TypedDict):
    """Keyword-argument mapping produced by ``AwsConfig.boto3_session``.

    Both keys are required. The key names match ``boto3.Session``'s
    ``profile_name`` / ``region_name`` parameters; presumably it is splatted
    into that constructor — confirm at call sites.
    """

    profile_name: str
    region_name: str
class AwsConfig(BaseModel):
    """AWS region and named credentials profile used by the CLI."""

    # Any S3 bucket location constraint, plus "us-east-1" listed explicitly
    # (presumably because the location-constraint literal set omits it — confirm).
    region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
    profile: str = "default"

    @property
    def boto3_session(self) -> Boto3SessionKwargs:
        """Return this config as ``{"profile_name": ..., "region_name": ...}``.

        A new dict is built on every access, so callers may mutate it freely.
        """
        return Boto3SessionKwargs(profile_name=self.profile, region_name=self.region)
# Fallback returned by InfraConfig.effective_bootstrap_qualifier when the stack
# name has no alphanumeric characters. This appears to be the AWS CDK default
# bootstrap qualifier — confirm before changing, existing bootstraps depend on it.
DEFAULT_BOOTSTRAP_QUALIFIER: str = "hnb659fds"
# Prefix of CLI-generated stack names; InfraConfig strips it when deriving the
# bootstrap qualifier so generated stacks do not all collapse to one value.
GENERATED_STACK_PREFIX: str = "qc-cli-mlops-"
class InfraConfig(BaseModel):
    """Infrastructure settings derived from the deployment stack name."""

    stack_name: str

    @staticmethod
    def _alnum_lower(text: str) -> str:
        """Lowercase ``text`` and drop every character outside ``[a-z0-9]``."""
        return re.sub(r"[^a-z0-9]", "", text.lower())

    @property
    def effective_bootstrap_qualifier(self) -> str:
        """Return the bootstrap qualifier derived from ``stack_name``.

        The qualifier is ``"q"`` followed by the lowercase alphanumeric
        characters of the stack name, truncated to 10 characters in total.

        When ``stack_name`` starts with ``GENERATED_STACK_PREFIX`` and has
        alphanumeric characters after the prefix, only that suffix is used.
        Without this, every generated name would truncate to the same
        qualifier, because the sanitized prefix alone fills the 10-character
        budget.

        If ``stack_name`` contains no alphanumeric characters at all,
        ``DEFAULT_BOOTSTRAP_QUALIFIER`` is returned unchanged.
        """
        sanitized = self._alnum_lower(self.stack_name)
        if not sanitized:
            return DEFAULT_BOOTSTRAP_QUALIFIER
        if self.stack_name.startswith(GENERATED_STACK_PREFIX):
            suffix = self._alnum_lower(self.stack_name.removeprefix(GENERATED_STACK_PREFIX))
            # A bare prefix (nothing alphanumeric after it) keeps the full
            # sanitized name instead of producing just "q".
            if suffix:
                sanitized = suffix
        # NOTE(review): the 10-char cap presumably mirrors the bootstrap
        # qualifier length limit — confirm. Because of it, names sharing their
        # first 9 alphanumeric characters map to the same qualifier.
        return f"q{sanitized}"[:10]

    @property
    def effective_toolkit_stack_name(self) -> str:
        """Return the bootstrap toolkit stack name, ``"<stack_name>-bootstrap"``.

        The stack name is used verbatim for every stack, including generated
        ones; no prefix stripping or sanitization is applied.
        """
        # NOTE(review): if generated stacks were meant to get a prefix-stripped
        # toolkit name (as the qualifier does), that is not implemented here —
        # confirm intent before changing, since it would rename deployed stacks.
        return f"{self.stack_name}-bootstrap"
class S3Config(BaseModel):
    """S3 bucket and key prefixes used for data and model artifacts."""

    # Placeholder-looking default; presumably meant to be overridden — confirm.
    bucket: str = "my-qc-mlops-bucket"
    # Key prefixes inside ``bucket``; the defaults end with "/" but no
    # normalization is applied here.
    data_prefix: str = "data/"
    model_prefix: str = "models/"
class TrainingConfig(BaseModel):
    """Settings for a SageMaker training job."""

    instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
    instance_count: int = 1
    # Empty string by default; no validation is applied here.
    image_uri: str = ""
    # Both optional; None means "not set".
    entry_point: str | None = None
    source_dir: str | None = None
    # default_factory gives every instance its own fresh dict.
    hyperparameters: dict[str, Any] = Field(default_factory=dict)
class SageMakerConfig(BaseModel):
    """SageMaker settings: execution role name and training-job options."""

    # Empty string by default; no validation is applied here.
    role_name: str = ""
    # Each instance gets its own TrainingConfig with default values.
    training: TrainingConfig = Field(default_factory=TrainingConfig)
class MlflowConfig(BaseModel):
    """MLflow tracking and model-registration settings.

    Validation: when ``mode`` is ``create`` or ``existing``, a non-empty
    ``tracking_server_name`` is required.
    """

    mode: MlflowMode = MlflowMode.disabled
    tracking_server_name: str | None = None
    experiment_name: str = "qc-cli-training"
    registered_model_name: str = "qc-cli-model"
    register_trained_models: bool = True
    artifact_prefix: str = "mlflow/"
    tracking_server_size: MlflowServerSize = MlflowServerSize.small
    # None means "not set".
    mlflow_version: str | None = None
    # NOTE(review): distinct from register_trained_models above; the difference
    # in semantics is not visible in this file — confirm.
    automatic_model_registration: bool = False
    # Free-form string; its format is not validated here.
    weekly_maintenance_window_start: str | None = None

    @model_validator(mode="after")
    def require_tracking_server_name(self) -> "MlflowConfig":
        """Reject create/existing configurations lacking a tracking server name.

        Raises:
            ValueError: if ``mode`` is ``create`` or ``existing`` and
                ``tracking_server_name`` is None or empty.
        """
        needs_server = self.mode in (MlflowMode.create, MlflowMode.existing)
        if needs_server and not self.tracking_server_name:
            raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
        return self
class Config(BaseModel):
    """Root CLI configuration.

    Only ``infra`` is required; every other section is built from its own
    defaults when omitted. Field order is kept as-is because it determines
    the order of fields in serialized output.
    """

    infra: InfraConfig  # required: there is no default
    aws: AwsConfig = Field(default_factory=AwsConfig)
    s3: S3Config = Field(default_factory=S3Config)
    sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
    mlflow: MlflowConfig = Field(default_factory=MlflowConfig)