command to start sagemaker training

include sample training
This commit is contained in:
2026-05-25 16:48:31 -04:00
parent 62ffe163e8
commit 0e728cc193
13 changed files with 796 additions and 5 deletions

131
src/aws/sagemaker.py Normal file
View File

@@ -0,0 +1,131 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
import boto3
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from mypy_boto3_sagemaker.type_defs import (
CreateTrainingJobRequestTypeDef,
ResourceConfigTypeDef,
TrainingJobSummaryTypeDef,
)
from src.config import Boto3SessionKwargs
@dataclass(frozen=True)
class TrainingJobRequest:
role_arn: str
image_uri: str
instance_type: TrainingInstanceTypeType
instance_count: int
s3_train_uri: str
s3_output_path: str
job_name: str
hyperparameters: dict[str, Any] = field(default_factory=dict)
entry_point: str | None = None
source_dir: str | None = None
@dataclass(frozen=True)
class TrainingJobStatus:
name: str
status: str
created: datetime | None
modified: datetime | None
model_artifacts: str | None
failure_reason: str | None
def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
return boto3.Session(**session).client("sagemaker")
def _upload_source_dir(
session: Boto3SessionKwargs,
source_dir: str,
s3_output_path: str,
job_name: str,
) -> str:
import io
import tarfile
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
tar.add(source_dir, arcname=".")
buf.seek(0)
without_scheme = s3_output_path.removeprefix("s3://")
bucket, _, prefix = without_scheme.partition("/")
key = f"{prefix.rstrip('/')}/{job_name}/source/sourcedir.tar.gz".lstrip("/")
boto3.Session(**session).client("s3").upload_fileobj(buf, bucket, key)
return f"s3://{bucket}/{key}"
def start_training_job(session: Boto3SessionKwargs, job: TrainingJobRequest) -> str:
hp = {k: str(v) for k, v in job.hyperparameters.items()}
if job.source_dir:
s3_code_uri = _upload_source_dir(
session,
job.source_dir,
job.s3_output_path,
job.job_name,
)
hp["sagemaker_program"] = job.entry_point or "train.py"
hp["sagemaker_submit_directory"] = s3_code_uri
resource_config: ResourceConfigTypeDef = {
"InstanceType": job.instance_type,
"InstanceCount": job.instance_count,
"VolumeSizeInGB": 30,
}
request: CreateTrainingJobRequestTypeDef = {
"TrainingJobName": job.job_name,
"AlgorithmSpecification": {"TrainingImage": job.image_uri, "TrainingInputMode": "File"},
"RoleArn": job.role_arn,
"InputDataConfig": [
{
"ChannelName": "train",
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3Uri": job.s3_train_uri,
"S3DataDistributionType": "FullyReplicated",
}
},
}
],
"OutputDataConfig": {"S3OutputPath": job.s3_output_path},
"ResourceConfig": resource_config,
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"HyperParameters": hp,
}
_sm(session).create_training_job(**request)
return job.job_name
def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> TrainingJobStatus:
resp = _sm(session).describe_training_job(TrainingJobName=job_name)
return TrainingJobStatus(
name=resp["TrainingJobName"],
status=resp["TrainingJobStatus"],
created=resp.get("CreationTime"),
modified=resp.get("LastModifiedTime"),
model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
failure_reason=resp.get("FailureReason"),
)
def list_training_jobs(
session: Boto3SessionKwargs,
max_results: int = 10,
) -> list[TrainingJobSummaryTypeDef]:
resp = _sm(session).list_training_jobs(
SortBy="CreationTime",
SortOrder="Descending",
MaxResults=max_results,
)
return list(resp["TrainingJobSummaries"])