create AWS infra

This commit is contained in:
2026-05-15 10:26:43 -04:00
parent 6563b4cc4b
commit 1bc5052d22
21 changed files with 1502 additions and 0 deletions

66
src/config.py Normal file
View File

@@ -0,0 +1,66 @@
from enum import Enum
from typing import Any, Literal
from mypy_boto3_s3.literals import BucketLocationConstraintType
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, model_validator
class MlflowMode(str, Enum):
disabled = "disabled"
create = "create"
existing = "existing"
class MlflowServerSize(str, Enum):
small = "Small"
medium = "Medium"
large = "Large"
class AwsConfig(BaseModel):
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
profile: str = "default"
class S3Config(BaseModel):
bucket: str = "my-onnx-bucket"
data_prefix: str = "data/"
model_prefix: str = "models/"
class TrainingConfig(BaseModel):
instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
instance_count: int = 1
image_uri: str = ""
entry_point: str | None = None
source_dir: str | None = None
hyperparameters: dict[str, Any] = Field(default_factory=dict)
class SageMakerConfig(BaseModel):
role_name: str = "qai-cli-sagemaker-role"
training: TrainingConfig = Field(default_factory=TrainingConfig)
class MlflowConfig(BaseModel):
mode: MlflowMode = MlflowMode.disabled
tracking_server_name: str | None = None
artifact_prefix: str = "mlflow/"
tracking_server_size: MlflowServerSize = MlflowServerSize.small
mlflow_version: str | None = None
automatic_model_registration: bool = False
weekly_maintenance_window_start: str | None = None
@model_validator(mode="after")
def require_tracking_server_name(self) -> "MlflowConfig":
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
return self
class Config(BaseModel):
aws: AwsConfig = Field(default_factory=AwsConfig)
s3: S3Config = Field(default_factory=S3Config)
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)