update naming
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
@@ -24,23 +23,13 @@ class AwsConfig(BaseModel):
|
||||
|
||||
|
||||
class S3Config(BaseModel):
|
||||
bucket: str = "my-onnx-bucket"
|
||||
bucket: str = "my-qc-mlops-bucket"
|
||||
data_prefix: str = "data/"
|
||||
model_prefix: str = "models/"
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
|
||||
instance_count: int = 1
|
||||
image_uri: str = ""
|
||||
entry_point: str | None = None
|
||||
source_dir: str | None = None
|
||||
hyperparameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SageMakerConfig(BaseModel):
|
||||
role_name: str = "qc-cli-sagemaker-role"
|
||||
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
|
||||
|
||||
class MlflowConfig(BaseModel):
|
||||
|
||||
@@ -33,7 +33,4 @@ def init(
|
||||
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
|
||||
|
||||
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||
console.print(
|
||||
"Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
|
||||
"before running other commands."
|
||||
)
|
||||
console.print("Edit it (especially [cyan]s3.bucket[/cyan]) before running other commands.")
|
||||
|
||||
Reference in New Issue
Block a user