"""Helpers for starting and inspecting SageMaker training jobs through boto3."""

import io
import tarfile
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any

import boto3
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from mypy_boto3_sagemaker.type_defs import (
    CreateTrainingJobRequestTypeDef,
    ResourceConfigTypeDef,
    TrainingJobSummaryTypeDef,
)

from src.config import Boto3SessionKwargs

# Fixed settings applied to every training job created by this module.
_VOLUME_SIZE_GB = 30
_MAX_RUNTIME_SECONDS = 86400  # 24 hours
# Script name used when a source directory is supplied without an entry point.
_DEFAULT_ENTRY_POINT = "train.py"


@dataclass(frozen=True)
class TrainingJobRequest:
    """Parameters consumed by :func:`start_training_job`.

    ``role_arn``, ``image_uri``, ``instance_type``, ``instance_count`` and
    ``job_name`` are passed straight into the ``create_training_job`` request.
    ``s3_train_uri`` becomes the S3 prefix of the single ``"train"`` input
    channel. ``s3_output_path`` is used as the job's output path and also as
    the base location for the uploaded source archive.

    ``hyperparameters`` values are converted with ``str()`` before submission.
    ``source_dir`` is a local directory that is archived and uploaded to S3
    when set. ``entry_point`` is only read when ``source_dir`` is set; it is
    ignored otherwise, and falls back to ``"train.py"`` when left as ``None``.
    """

    role_arn: str
    image_uri: str
    instance_type: TrainingInstanceTypeType
    instance_count: int
    s3_train_uri: str
    s3_output_path: str
    job_name: str
    hyperparameters: dict[str, Any] = field(default_factory=dict)
    entry_point: str | None = None
    source_dir: str | None = None


@dataclass(frozen=True)
class TrainingJobStatus:
    """Selected fields of a ``describe_training_job`` response.

    ``raw`` holds a shallow ``dict`` copy of the full response. The optional
    fields are ``None`` when the corresponding key is absent from it.
    """

    name: str
    status: str
    created: datetime | None
    modified: datetime | None
    model_artifacts: str | None
    failure_reason: str | None
    raw: dict[str, Any] = field(default_factory=dict)


def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
    """Build a SageMaker client from a new boto3 session using ``session`` kwargs."""
    return boto3.Session(**session).client("sagemaker")


def _upload_source_dir(
    session: Boto3SessionKwargs,
    source_dir: str,
    s3_output_path: str,
    job_name: str,
) -> str:
    """Archive ``source_dir`` as a gzipped tarball and upload it to S3.

    The archive is built in memory and stored under
    ``<s3_output_path>/<job_name>/source/sourcedir.tar.gz``.

    Args:
        session: Keyword arguments for ``boto3.Session``.
        source_dir: Local directory whose contents are archived.
        s3_output_path: S3 location; an ``s3://`` prefix is stripped if present.
        job_name: Training job name, used as a path component of the key.

    Returns:
        The ``s3://`` URI of the uploaded archive.
    """
    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w:gz") as tar:
        # arcname="." places the directory's contents at the archive root.
        tar.add(source_dir, arcname=".")
    # Rewind so the upload reads the archive from its first byte.
    archive.seek(0)

    # Everything before the first "/" is the bucket; the rest is the key prefix.
    location = s3_output_path.removeprefix("s3://")
    bucket, _, prefix = location.partition("/")
    # lstrip drops the leading "/" that appears when the prefix is empty.
    key = f"{prefix.rstrip('/')}/{job_name}/source/sourcedir.tar.gz".lstrip("/")

    s3 = boto3.Session(**session).client("s3")
    s3.upload_fileobj(archive, bucket, key)
    return f"s3://{bucket}/{key}"


def start_training_job(session: Boto3SessionKwargs, job: TrainingJobRequest) -> str:
    """Create a SageMaker training job described by ``job``.

    When ``job.source_dir`` is set, the directory is uploaded first and the
    ``sagemaker_program`` / ``sagemaker_submit_directory`` hyperparameters are
    set to the entry point and the uploaded archive URI, overriding any
    caller-supplied values for those two keys.

    Args:
        session: Keyword arguments for ``boto3.Session``.
        job: The training job parameters.

    Returns:
        ``job.job_name``, unchanged.
    """
    hyperparameters = {name: str(value) for name, value in job.hyperparameters.items()}

    if job.source_dir:
        code_uri = _upload_source_dir(
            session,
            job.source_dir,
            job.s3_output_path,
            job.job_name,
        )
        # NOTE(review): these keys are presumably read by the training image to
        # locate and run the user script -- confirm against the image in use.
        hyperparameters["sagemaker_program"] = job.entry_point or _DEFAULT_ENTRY_POINT
        hyperparameters["sagemaker_submit_directory"] = code_uri

    resource_config: ResourceConfigTypeDef = {
        "InstanceType": job.instance_type,
        "InstanceCount": job.instance_count,
        "VolumeSizeInGB": _VOLUME_SIZE_GB,
    }
    request: CreateTrainingJobRequestTypeDef = {
        "TrainingJobName": job.job_name,
        "AlgorithmSpecification": {
            "TrainingImage": job.image_uri,
            "TrainingInputMode": "File",
        },
        "RoleArn": job.role_arn,
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": job.s3_train_uri,
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
            }
        ],
        "OutputDataConfig": {"S3OutputPath": job.s3_output_path},
        "ResourceConfig": resource_config,
        "StoppingCondition": {"MaxRuntimeInSeconds": _MAX_RUNTIME_SECONDS},
        "HyperParameters": hyperparameters,
    }
    _sm(session).create_training_job(**request)
    return job.job_name


def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> TrainingJobStatus:
    """Describe the training job ``job_name`` and wrap the response.

    Args:
        session: Keyword arguments for ``boto3.Session``.
        job_name: Name of the training job to describe.

    Returns:
        A :class:`TrainingJobStatus` built from the ``describe_training_job``
        response, with the full response copied into ``raw``.
    """
    resp = _sm(session).describe_training_job(TrainingJobName=job_name)
    return TrainingJobStatus(
        name=resp["TrainingJobName"],
        status=resp["TrainingJobStatus"],
        created=resp.get("CreationTime"),
        modified=resp.get("LastModifiedTime"),
        model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
        failure_reason=resp.get("FailureReason"),
        raw=dict(resp),
    )


def get_model_artifacts(region: str, profile: str, job_name: str) -> str:
    """Return the S3 URI of the model artifacts produced by ``job_name``.

    Unlike the other helpers in this module, the boto3 session is built from
    an explicit profile name and region rather than from session kwargs.

    Args:
        region: Region name passed to ``boto3.Session``.
        profile: Profile name passed to ``boto3.Session``.
        job_name: Name of the training job to describe.

    Returns:
        The ``S3ModelArtifacts`` value from the job description.

    Raises:
        RuntimeError: If the description has no (or an empty) artifact URI.
    """
    client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
    resp = client.describe_training_job(TrainingJobName=job_name)
    artifact = resp.get("ModelArtifacts", {}).get("S3ModelArtifacts")
    if not artifact:
        raise RuntimeError(f"Training job '{job_name}' does not have model artifacts yet.")
    return str(artifact)


def list_training_jobs(
    session: Boto3SessionKwargs,
    max_results: int = 10,
) -> list[TrainingJobSummaryTypeDef]:
    """List training jobs, most recently created first.

    Only a single ``list_training_jobs`` call is made; no further pages are
    fetched.

    Args:
        session: Keyword arguments for ``boto3.Session``.
        max_results: Value sent as ``MaxResults`` in the request.

    Returns:
        The ``TrainingJobSummaries`` entries of the response, as a new list.
    """
    resp = _sm(session).list_training_jobs(
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=max_results,
    )
    return list(resp["TrainingJobSummaries"])