create AWS infra
This commit is contained in:
66
src/config.py
Normal file
66
src/config.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class MlflowMode(str, Enum):
|
||||
disabled = "disabled"
|
||||
create = "create"
|
||||
existing = "existing"
|
||||
|
||||
|
||||
class MlflowServerSize(str, Enum):
|
||||
small = "Small"
|
||||
medium = "Medium"
|
||||
large = "Large"
|
||||
|
||||
|
||||
class AwsConfig(BaseModel):
|
||||
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
|
||||
profile: str = "default"
|
||||
|
||||
|
||||
class S3Config(BaseModel):
|
||||
bucket: str = "my-onnx-bucket"
|
||||
data_prefix: str = "data/"
|
||||
model_prefix: str = "models/"
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
|
||||
instance_count: int = 1
|
||||
image_uri: str = ""
|
||||
entry_point: str | None = None
|
||||
source_dir: str | None = None
|
||||
hyperparameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SageMakerConfig(BaseModel):
|
||||
role_name: str = "qai-cli-sagemaker-role"
|
||||
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
|
||||
|
||||
class MlflowConfig(BaseModel):
|
||||
mode: MlflowMode = MlflowMode.disabled
|
||||
tracking_server_name: str | None = None
|
||||
artifact_prefix: str = "mlflow/"
|
||||
tracking_server_size: MlflowServerSize = MlflowServerSize.small
|
||||
mlflow_version: str | None = None
|
||||
automatic_model_registration: bool = False
|
||||
weekly_maintenance_window_start: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_tracking_server_name(self) -> "MlflowConfig":
|
||||
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
|
||||
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
|
||||
return self
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
aws: AwsConfig = Field(default_factory=AwsConfig)
|
||||
s3: S3Config = Field(default_factory=S3Config)
|
||||
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
|
||||
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
|
||||
Reference in New Issue
Block a user