command to start sagemaker training
include sample training
This commit is contained in:
17
src/aws/iam.py
Normal file
17
src/aws/iam.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
from mypy_boto3_iam import IAMClient
|
||||
|
||||
|
||||
def _client(profile: str) -> IAMClient:
|
||||
return boto3.Session(profile_name=profile).client("iam")
|
||||
|
||||
|
||||
def get_role_arn(profile: str, role_name: str) -> str | None:
|
||||
client = _client(profile)
|
||||
try:
|
||||
return client.get_role(RoleName=role_name)["Role"]["Arn"]
|
||||
except ClientError as e:
|
||||
if e.response.get("Error", {}).get("Code") == "NoSuchEntity":
|
||||
return None
|
||||
raise
|
||||
131
src/aws/sagemaker.py
Normal file
131
src/aws/sagemaker.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from mypy_boto3_sagemaker import SageMakerClient
|
||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||
from mypy_boto3_sagemaker.type_defs import (
|
||||
CreateTrainingJobRequestTypeDef,
|
||||
ResourceConfigTypeDef,
|
||||
TrainingJobSummaryTypeDef,
|
||||
)
|
||||
|
||||
from src.config import Boto3SessionKwargs
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainingJobRequest:
|
||||
role_arn: str
|
||||
image_uri: str
|
||||
instance_type: TrainingInstanceTypeType
|
||||
instance_count: int
|
||||
s3_train_uri: str
|
||||
s3_output_path: str
|
||||
job_name: str
|
||||
hyperparameters: dict[str, Any] = field(default_factory=dict)
|
||||
entry_point: str | None = None
|
||||
source_dir: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainingJobStatus:
|
||||
name: str
|
||||
status: str
|
||||
created: datetime | None
|
||||
modified: datetime | None
|
||||
model_artifacts: str | None
|
||||
failure_reason: str | None
|
||||
|
||||
|
||||
def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
|
||||
return boto3.Session(**session).client("sagemaker")
|
||||
|
||||
|
||||
def _upload_source_dir(
|
||||
session: Boto3SessionKwargs,
|
||||
source_dir: str,
|
||||
s3_output_path: str,
|
||||
job_name: str,
|
||||
) -> str:
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
buf = io.BytesIO()
|
||||
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||
tar.add(source_dir, arcname=".")
|
||||
buf.seek(0)
|
||||
|
||||
without_scheme = s3_output_path.removeprefix("s3://")
|
||||
bucket, _, prefix = without_scheme.partition("/")
|
||||
key = f"{prefix.rstrip('/')}/{job_name}/source/sourcedir.tar.gz".lstrip("/")
|
||||
|
||||
boto3.Session(**session).client("s3").upload_fileobj(buf, bucket, key)
|
||||
return f"s3://{bucket}/{key}"
|
||||
|
||||
|
||||
def start_training_job(session: Boto3SessionKwargs, job: TrainingJobRequest) -> str:
|
||||
hp = {k: str(v) for k, v in job.hyperparameters.items()}
|
||||
|
||||
if job.source_dir:
|
||||
s3_code_uri = _upload_source_dir(
|
||||
session,
|
||||
job.source_dir,
|
||||
job.s3_output_path,
|
||||
job.job_name,
|
||||
)
|
||||
hp["sagemaker_program"] = job.entry_point or "train.py"
|
||||
hp["sagemaker_submit_directory"] = s3_code_uri
|
||||
|
||||
resource_config: ResourceConfigTypeDef = {
|
||||
"InstanceType": job.instance_type,
|
||||
"InstanceCount": job.instance_count,
|
||||
"VolumeSizeInGB": 30,
|
||||
}
|
||||
request: CreateTrainingJobRequestTypeDef = {
|
||||
"TrainingJobName": job.job_name,
|
||||
"AlgorithmSpecification": {"TrainingImage": job.image_uri, "TrainingInputMode": "File"},
|
||||
"RoleArn": job.role_arn,
|
||||
"InputDataConfig": [
|
||||
{
|
||||
"ChannelName": "train",
|
||||
"DataSource": {
|
||||
"S3DataSource": {
|
||||
"S3DataType": "S3Prefix",
|
||||
"S3Uri": job.s3_train_uri,
|
||||
"S3DataDistributionType": "FullyReplicated",
|
||||
}
|
||||
},
|
||||
}
|
||||
],
|
||||
"OutputDataConfig": {"S3OutputPath": job.s3_output_path},
|
||||
"ResourceConfig": resource_config,
|
||||
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
|
||||
"HyperParameters": hp,
|
||||
}
|
||||
_sm(session).create_training_job(**request)
|
||||
return job.job_name
|
||||
|
||||
|
||||
def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> TrainingJobStatus:
|
||||
resp = _sm(session).describe_training_job(TrainingJobName=job_name)
|
||||
return TrainingJobStatus(
|
||||
name=resp["TrainingJobName"],
|
||||
status=resp["TrainingJobStatus"],
|
||||
created=resp.get("CreationTime"),
|
||||
modified=resp.get("LastModifiedTime"),
|
||||
model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
|
||||
failure_reason=resp.get("FailureReason"),
|
||||
)
|
||||
|
||||
|
||||
def list_training_jobs(
|
||||
session: Boto3SessionKwargs,
|
||||
max_results: int = 10,
|
||||
) -> list[TrainingJobSummaryTypeDef]:
|
||||
resp = _sm(session).list_training_jobs(
|
||||
SortBy="CreationTime",
|
||||
SortOrder="Descending",
|
||||
MaxResults=max_results,
|
||||
)
|
||||
return list(resp["TrainingJobSummaries"])
|
||||
126
src/commands/train.py
Normal file
126
src/commands/train.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from datetime import datetime
|
||||
|
||||
import typer
|
||||
from rich.table import Table
|
||||
|
||||
from src import state as state_ops
|
||||
from src.aws import iam
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
|
||||
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||
|
||||
_STATUS_COLOR = {
|
||||
"Completed": "green",
|
||||
"Failed": "red",
|
||||
"InProgress": "yellow",
|
||||
"Stopping": "yellow",
|
||||
"Stopped": "dim",
|
||||
}
|
||||
|
||||
|
||||
def _config_dir(config_path: str) -> str:
|
||||
from pathlib import Path
|
||||
return str(Path(config_path).parent)
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if not cfg.sagemaker.training.image_uri:
|
||||
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
|
||||
CONSOLE.print(
|
||||
"Find pre-built images at: "
|
||||
"https://aws.github.io/deep-learning-containers/reference/available_images"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||
if not role_arn:
|
||||
CONSOLE.print(f"[red]IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
||||
s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
|
||||
s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"
|
||||
|
||||
CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...")
|
||||
training_job = sm_ops.TrainingJobRequest(
|
||||
role_arn=role_arn,
|
||||
image_uri=cfg.sagemaker.training.image_uri,
|
||||
instance_type=cfg.sagemaker.training.instance_type,
|
||||
instance_count=cfg.sagemaker.training.instance_count,
|
||||
s3_train_uri=s3_train_uri,
|
||||
s3_output_path=s3_output,
|
||||
job_name=job_name,
|
||||
hyperparameters=cfg.sagemaker.training.hyperparameters,
|
||||
entry_point=cfg.sagemaker.training.entry_point,
|
||||
source_dir=cfg.sagemaker.training.source_dir,
|
||||
)
|
||||
sm_ops.start_training_job(cfg.aws.boto3_session, training_job)
|
||||
|
||||
state_ops.write_state(_config_dir(config), last_training_job=job_name)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
|
||||
|
||||
@app.command()
|
||||
def status(
|
||||
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Show training job status."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if not job_name:
|
||||
job_name = state_ops.get_last_training_job(_config_dir(config))
|
||||
if not job_name:
|
||||
CONSOLE.print(
|
||||
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
color = _STATUS_COLOR.get(status.status, "white")
|
||||
|
||||
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||
if status.created:
|
||||
CONSOLE.print(f"Created: {status.created}")
|
||||
if status.model_artifacts:
|
||||
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||
if status.failure_reason:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
def list_jobs(
|
||||
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""List recent training jobs."""
|
||||
cfg = load_cfg(config)
|
||||
jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit)
|
||||
|
||||
if not jobs:
|
||||
CONSOLE.print("[yellow]No training jobs found.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Training Jobs")
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("Status")
|
||||
table.add_column("Created")
|
||||
|
||||
for job in jobs:
|
||||
status_value = str(job["TrainingJobStatus"])
|
||||
color = _STATUS_COLOR.get(status_value, "white")
|
||||
table.add_row(
|
||||
str(job["TrainingJobName"]),
|
||||
f"[{color}]{status_value}[/{color}]",
|
||||
str(job.get("CreationTime", "")),
|
||||
)
|
||||
|
||||
CONSOLE.print(table)
|
||||
@@ -1,7 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
@@ -17,10 +18,19 @@ class MlflowServerSize(str, Enum):
|
||||
large = "Large"
|
||||
|
||||
|
||||
class Boto3SessionKwargs(TypedDict):
|
||||
profile_name: str
|
||||
region_name: str
|
||||
|
||||
|
||||
class AwsConfig(BaseModel):
|
||||
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
|
||||
profile: str = "default"
|
||||
|
||||
@property
|
||||
def boto3_session(self) -> Boto3SessionKwargs:
|
||||
return {"profile_name": self.profile, "region_name": self.region}
|
||||
|
||||
|
||||
class S3Config(BaseModel):
|
||||
bucket: str = "my-qc-mlops-bucket"
|
||||
@@ -28,8 +38,18 @@ class S3Config(BaseModel):
|
||||
model_prefix: str = "models/"
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
|
||||
instance_count: int = 1
|
||||
image_uri: str = ""
|
||||
entry_point: str | None = None
|
||||
source_dir: str | None = None
|
||||
hyperparameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SageMakerConfig(BaseModel):
|
||||
role_name: str = "qc-cli-sagemaker-role"
|
||||
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
|
||||
|
||||
class MlflowConfig(BaseModel):
|
||||
|
||||
@@ -6,7 +6,7 @@ from rich.console import Console
|
||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
||||
|
||||
from src.aws import s3 as s3_ops
|
||||
from src.commands import infra
|
||||
from src.commands import infra, train
|
||||
from src.commands.utils import CONFIG_OPT, load_cfg
|
||||
from src.config import Config
|
||||
|
||||
@@ -15,6 +15,7 @@ app = typer.Typer(
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(infra.app, name="infra")
|
||||
app.add_typer(train.app, name="train")
|
||||
|
||||
console = Console()
|
||||
|
||||
@@ -36,7 +37,10 @@ def init(
|
||||
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
|
||||
|
||||
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||
console.print("Edit it (especially [cyan]s3.bucket[/cyan]) before running other commands.")
|
||||
console.print(
|
||||
"Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
|
||||
"before running other commands."
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
30
src/state.py
Normal file
30
src/state.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
STATE_FILE = ".qc-cli.json"
|
||||
|
||||
|
||||
def _path(config_dir: str) -> Path:
|
||||
return Path(config_dir) / STATE_FILE
|
||||
|
||||
|
||||
def read_state(config_dir: str = ".") -> dict[str, Any]:
|
||||
path = _path(config_dir)
|
||||
if not path.exists():
|
||||
return {}
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_state(config_dir: str = ".", **updates: str | None) -> None:
|
||||
path = _path(config_dir)
|
||||
state = read_state(config_dir)
|
||||
state.update(updates)
|
||||
with open(path, "w") as f:
|
||||
json.dump(state, f, indent=2)
|
||||
|
||||
|
||||
def get_last_training_job(config_dir: str = ".") -> str | None:
|
||||
value = read_state(config_dir).get("last_training_job")
|
||||
return str(value) if value else None
|
||||
Reference in New Issue
Block a user