144 lines
4.4 KiB
Python
144 lines
4.4 KiB
Python
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
import boto3
|
|
from mypy_boto3_sagemaker import SageMakerClient
|
|
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
|
from mypy_boto3_sagemaker.type_defs import (
|
|
CreateTrainingJobRequestTypeDef,
|
|
ResourceConfigTypeDef,
|
|
TrainingJobSummaryTypeDef,
|
|
)
|
|
|
|
from src.config import Boto3SessionKwargs
|
|
|
|
|
|
@dataclass(frozen=True)
class TrainingJobRequest:
    """Immutable description of a single SageMaker training job to launch.

    ``start_training_job`` reads these fields to assemble the
    ``create_training_job`` request; the per-field notes below describe
    where each value ends up in that request.
    """

    # Sent as ``RoleArn``.
    role_arn: str
    # Sent as ``AlgorithmSpecification.TrainingImage``.
    image_uri: str
    # Sent as ``ResourceConfig.InstanceType``.
    instance_type: TrainingInstanceTypeType
    # Sent as ``ResourceConfig.InstanceCount``.
    instance_count: int
    # ``S3Uri`` of the single "train" input channel.
    s3_train_uri: str
    # Sent as ``OutputDataConfig.S3OutputPath``; the packaged source archive
    # (when ``source_dir`` is set) is uploaded beneath this location as well.
    s3_output_path: str
    # Sent as ``TrainingJobName`` and returned by ``start_training_job``.
    job_name: str
    # Every value is converted with ``str()`` before submission.
    hyperparameters: dict[str, Any] = field(default_factory=dict)
    # Becomes the ``sagemaker_program`` hyperparameter, but only when
    # ``source_dir`` is set; "train.py" is used when this is empty/None.
    entry_point: str | None = None
    # Local path that is tarred and uploaded to S3 when truthy.
    source_dir: str | None = None
|
|
|
|
|
|
@dataclass(frozen=True)
class TrainingJobStatus:
    """Immutable snapshot of a training job's state.

    ``get_training_job_status`` builds instances of this class from a
    ``describe_training_job`` response; the per-field notes name the
    response key each value is read from.
    """

    # ``TrainingJobName`` from the response.
    name: str
    # ``TrainingJobStatus`` from the response.
    status: str
    # ``CreationTime`` if present in the response, otherwise None.
    created: datetime | None
    # ``LastModifiedTime`` if present in the response, otherwise None.
    modified: datetime | None
    # ``ModelArtifacts.S3ModelArtifacts`` if present, otherwise None.
    model_artifacts: str | None
    # ``FailureReason`` if present in the response, otherwise None.
    failure_reason: str | None
    # Shallow ``dict`` copy of the response the snapshot was built from.
    raw: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
    """Return a SageMaker client for the given session settings.

    ``session`` is unpacked as keyword arguments into ``boto3.Session``.
    A new session and client are created on every call.
    """
    boto_session = boto3.Session(**session)
    sagemaker_client = boto_session.client("sagemaker")
    return sagemaker_client
|
|
|
|
|
|
def _upload_source_dir(
|
|
session: Boto3SessionKwargs,
|
|
source_dir: str,
|
|
s3_output_path: str,
|
|
job_name: str,
|
|
) -> str:
|
|
import io
|
|
import tarfile
|
|
|
|
buf = io.BytesIO()
|
|
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
|
tar.add(source_dir, arcname=".")
|
|
buf.seek(0)
|
|
|
|
without_scheme = s3_output_path.removeprefix("s3://")
|
|
bucket, _, prefix = without_scheme.partition("/")
|
|
key = f"{prefix.rstrip('/')}/{job_name}/source/sourcedir.tar.gz".lstrip("/")
|
|
|
|
boto3.Session(**session).client("s3").upload_fileobj(buf, bucket, key)
|
|
return f"s3://{bucket}/{key}"
|
|
|
|
|
|
def start_training_job(
    session: Boto3SessionKwargs,
    job: TrainingJobRequest,
    *,
    volume_size_gb: int = 30,
    max_runtime_seconds: int = 86400,
) -> str:
    """Submit a SageMaker training job described by ``job``.

    Hyperparameter values are converted to strings. When ``job.source_dir``
    is set, the directory is packaged and uploaded via
    ``_upload_source_dir`` and the ``sagemaker_program`` and
    ``sagemaker_submit_directory`` hyperparameters are set (overwriting any
    caller-supplied values under those keys). ``job.entry_point`` is only
    consulted in that case and falls back to ``"train.py"``.

    The job reads one ``"train"`` channel from ``job.s3_train_uri`` in
    ``File`` input mode and writes output to ``job.s3_output_path``.

    Args:
        session: Keyword arguments unpacked into ``boto3.Session``.
        job: The training job description.
        volume_size_gb: Value sent as ``ResourceConfig.VolumeSizeInGB``.
        max_runtime_seconds: Value sent as
            ``StoppingCondition.MaxRuntimeInSeconds``.

    Returns:
        ``job.job_name``, unchanged.
    """
    hp = {k: str(v) for k, v in job.hyperparameters.items()}

    if job.source_dir:
        s3_code_uri = _upload_source_dir(
            session,
            job.source_dir,
            job.s3_output_path,
            job.job_name,
        )
        hp["sagemaker_program"] = job.entry_point or "train.py"
        hp["sagemaker_submit_directory"] = s3_code_uri

    resource_config: ResourceConfigTypeDef = {
        "InstanceType": job.instance_type,
        "InstanceCount": job.instance_count,
        "VolumeSizeInGB": volume_size_gb,
    }
    request: CreateTrainingJobRequestTypeDef = {
        "TrainingJobName": job.job_name,
        "AlgorithmSpecification": {
            "TrainingImage": job.image_uri,
            "TrainingInputMode": "File",
        },
        "RoleArn": job.role_arn,
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": job.s3_train_uri,
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
            }
        ],
        "OutputDataConfig": {"S3OutputPath": job.s3_output_path},
        "ResourceConfig": resource_config,
        "StoppingCondition": {"MaxRuntimeInSeconds": max_runtime_seconds},
        "HyperParameters": hp,
    }
    _sm(session).create_training_job(**request)
    return job.job_name
|
|
|
|
|
|
def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> TrainingJobStatus:
    """Describe ``job_name`` and return its state as a ``TrainingJobStatus``.

    ``TrainingJobName`` and ``TrainingJobStatus`` are read as required keys
    of the response; the timestamps, model artifact location and failure
    reason are ``None`` when the response does not carry them. ``raw`` holds
    a shallow ``dict`` copy of the whole response.
    """
    description = _sm(session).describe_training_job(TrainingJobName=job_name)
    artifacts = description.get("ModelArtifacts", {})
    created_at = description.get("CreationTime")
    modified_at = description.get("LastModifiedTime")
    return TrainingJobStatus(
        name=description["TrainingJobName"],
        status=description["TrainingJobStatus"],
        created=created_at,
        modified=modified_at,
        model_artifacts=artifacts.get("S3ModelArtifacts"),
        failure_reason=description.get("FailureReason"),
        raw=dict(description),
    )
|
|
|
|
|
|
def get_model_artifacts(region: str, profile: str, job_name: str) -> str:
    """Return the S3 location of the model artifacts produced by ``job_name``.

    Unlike the other helpers in this module, the boto3 session is built
    directly from ``profile`` and ``region`` rather than from a session
    kwargs mapping.

    Raises:
        RuntimeError: If the described job has no ``S3ModelArtifacts`` value
            (missing or empty).
    """
    boto_session = boto3.Session(profile_name=profile, region_name=region)
    client = boto_session.client("sagemaker")
    description = client.describe_training_job(TrainingJobName=job_name)
    location = description.get("ModelArtifacts", {}).get("S3ModelArtifacts")
    if location:
        return str(location)
    raise RuntimeError(f"Training job '{job_name}' does not have model artifacts yet.")
|
|
|
|
|
|
def list_training_jobs(
    session: Boto3SessionKwargs,
    max_results: int = 10,
) -> list[TrainingJobSummaryTypeDef]:
    """Return up to ``max_results`` training job summaries, newest first.

    The SageMaker ``ListTrainingJobs`` API accepts at most 100 results per
    call, so larger requests are served by following ``NextToken`` across
    several calls until enough summaries are collected or the listing ends.

    Args:
        session: Keyword arguments unpacked into ``boto3.Session``.
        max_results: Upper bound on the number of summaries returned. A
            non-positive value is still sent to the API on the first call,
            so it is rejected by the service's own validation.

    Returns:
        Summaries sorted by creation time, descending.
    """
    page_limit = 100  # per-call cap of ListTrainingJobs.MaxResults
    client = _sm(session)
    params: dict[str, Any] = {
        "SortBy": "CreationTime",
        "SortOrder": "Descending",
        "MaxResults": min(max_results, page_limit),
    }
    summaries: list[TrainingJobSummaryTypeDef] = []
    while True:
        resp = client.list_training_jobs(**params)
        summaries.extend(resp["TrainingJobSummaries"])
        remaining = max_results - len(summaries)
        next_token = resp.get("NextToken")
        if remaining <= 0 or not next_token:
            break
        # Ask only for what is still missing so the total never exceeds
        # max_results.
        params["NextToken"] = next_token
        params["MaxResults"] = min(remaining, page_limit)
    return summaries
|