From 0e728cc193ae59d6961fafb5bc5dbb937a9b1fe9 Mon Sep 17 00:00:00 2001 From: slalom Date: Mon, 25 May 2026 16:48:31 -0400 Subject: [PATCH] command to start sagemaker training include sample training --- .gitignore | 3 +- README.md | 30 ++- examples/training/README.md | 90 +++++++++ examples/training/download_flower_photos.sh | 40 ++++ examples/training/run_training.sh | 111 +++++++++++ examples/training/source/requirements.txt | 1 + examples/training/source/train.py | 192 ++++++++++++++++++++ src/aws/iam.py | 17 ++ src/aws/sagemaker.py | 131 +++++++++++++ src/commands/train.py | 126 +++++++++++++ src/config.py | 22 ++- src/main.py | 8 +- src/state.py | 30 +++ 13 files changed, 796 insertions(+), 5 deletions(-) create mode 100644 examples/training/README.md create mode 100755 examples/training/download_flower_photos.sh create mode 100755 examples/training/run_training.sh create mode 100644 examples/training/source/requirements.txt create mode 100644 examples/training/source/train.py create mode 100644 src/aws/iam.py create mode 100644 src/aws/sagemaker.py create mode 100644 src/commands/train.py create mode 100644 src/state.py diff --git a/.gitignore b/.gitignore index 7d759a9..9932880 100644 --- a/.gitignore +++ b/.gitignore @@ -220,4 +220,5 @@ __marimo__/ .venv/ config.yaml cdk.out/ -.qc-cli-infra* \ No newline at end of file +.qc-cli*.json +examples/*/data/ diff --git a/README.md b/README.md index 6d2050f..0c0cd66 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,16 @@ qc-cli --help # 1. Create config.yaml in the current directory qc-cli init -# 2. Edit config.yaml — at minimum set s3.bucket and sagemaker.role_name +# 2. Edit config.yaml — at minimum set s3.bucket and sagemaker.training.image_uri # 3. Provision AWS infrastructure (S3 bucket + SageMaker IAM role). # This is the step that requires the AWS CDK CLI. qc-cli infra setup + +# 4. Upload training data, then submit a SageMaker training job. +qc-cli upload ./my-dataset +qc-cli train start +qc-cli train status ``` ## Configuration @@ -51,8 +56,17 @@ s3: sagemaker: role_name: qc-cli-sagemaker-role + training: + image_uri: "" # ECR URI for your training container + instance_type: ml.m5.xlarge + instance_count: 1 + entry_point: null # Optional: script inside source_dir + source_dir: null # Optional: local dir packaged and uploaded automatically + hyperparameters: {} ``` +`hyperparameters` is a flat map of values passed to the training container. Valid keys depend on the selected training image and entry point. + To provision an MLflow tracking server, set: ```yaml @@ -101,6 +115,19 @@ qc-cli upload --s3-key Upload a file to a custom S3 key Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads default to `s3:////`. Directory uploads are recursive, preserve paths relative to the uploaded directory, and place files under `s3:////`. +### `train` + +``` +qc-cli train start Submit a SageMaker training job +qc-cli train status [job-name] Show job status; defaults to the last submitted job +qc-cli train list List recent training jobs +qc-cli train list --limit 3 Show a custom number of recent jobs +``` + +`train start` uses `s3:////` as the training channel and writes outputs under `s3:////`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container. + +The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. + ## AWS permissions required The IAM user or role running the CLI needs: @@ -111,6 +138,7 @@ The IAM user or role running the CLI needs: | CreateRole, GetRole, DeleteRole, AttachRolePolicy, DetachRolePolicy | IAM | | CreateStack, UpdateStack, DeleteStack, DescribeStacks, DescribeStackEvents | CloudFormation | | GetCallerIdentity | STS | +| CreateTrainingJob, DescribeTrainingJob, ListTrainingJobs | SageMaker AI | | CreateMlflowTrackingServer, DescribeMlflowTrackingServer, DeleteMlflowTrackingServer | SageMaker AI, when `mlflow.mode` is `create` or `existing` | `AdministratorAccess` covers all of the above. diff --git a/examples/training/README.md b/examples/training/README.md new file mode 100644 index 0000000..607eaa4 --- /dev/null +++ b/examples/training/README.md @@ -0,0 +1,90 @@ +# SageMaker Training Example + +This example downloads a small image-classification dataset, uploads it through `qc-cli`, and submits a live SageMaker training job. + +## Prerequisites + +- AWS credentials configured for the profile in `config.yaml` +- Infrastructure already deployed with `qc-cli infra setup` +- `config.yaml` updated with: + +```yaml +s3: + bucket: your-bucket-name + +sagemaker: + role_name: + training: + image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1 + instance_type: ml.m4.xlarge + instance_count: 1 + source_dir: examples/training/source + entry_point: train.py + hyperparameters: + epochs: 1 + batch-size: 32 + learning-rate: 0.001 + image-size: 160 + validation-split: 0.2 +``` + +## Training Hyperparameters + +Values under `sagemaker.training.hyperparameters` are passed to the training entry point as command-line arguments. For this example, they map to arguments defined in [source/train.py](source/train.py). + +Supported by this example: + +| Name | Type | Default | Description | +|---|---:|---:|---| +| `epochs` | int | `1` | Number of training epochs. | +| `batch-size` | int | `32` | Images per training batch. | +| `learning-rate` | float | `0.001` | Adam optimizer learning rate. | +| `image-size` | int | `160` | Resize images to square `image-size x image-size`. | +| `validation-split` | float | `0.2` | Fraction of data used for validation. | +| `max-samples` | int | `0` | Optional cap for smoke tests; `0` means use all images. | +| `seed` | int | `13` | Random seed for reproducible splitting. | +| `num-workers` | int | `2` | DataLoader worker count. | + +Do not set `train-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`. + +## 1. Download The Dataset + +```bash +bash examples/training/download_flower_photos.sh +``` + +This creates: + +```text +examples/training/data/flower_photos_sagemaker/ + daisy/ + dandelion/ + roses/ + sunflowers/ + tulips/ +``` + +## 2. Run Training + +Run the training script and wait until it finishes: + +```bash +bash examples/training/run_training.sh --config config.yaml --wait +``` + +Use a dataset that is already uploaded to `s3.data_prefix`: + +```bash +bash examples/training/run_training.sh \ + --config config.yaml \ + --skip-upload \ + --wait +``` + +## Notes + +- The default dataset path is `examples/training/data/flower_photos_sagemaker`. +- Uploaded data uses the `s3.bucket` and `s3.data_prefix` values from `config.yaml`. +- Training artifacts are written under `s3:////`. +- The SageMaker `model.tar.gz` contains `model.onnx`, `model.pt`, `class_to_idx.json`, and `metrics.json`. +- SageMaker packages `examples/training/source`, installs `requirements.txt`, and runs `train.py`. diff --git a/examples/training/download_flower_photos.sh b/examples/training/download_flower_photos.sh new file mode 100755 index 0000000..5df0429 --- /dev/null +++ b/examples/training/download_flower_photos.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +set -euo pipefail + +DATASET_URL="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" +DEST_DIR="${1:-examples/training/data}" +ARCHIVE_PATH="${DEST_DIR}/flower_photos.tgz" +RAW_DATASET_DIR="${DEST_DIR}/flower_photos" +DATASET_DIR="${DEST_DIR}/flower_photos_sagemaker" +CLASS_NAMES=("daisy" "dandelion" "roses" "sunflowers" "tulips") + +mkdir -p "${DEST_DIR}" + +if [[ -d "${DATASET_DIR}" ]]; then + echo "Dataset already exists: ${DATASET_DIR}" + echo "Use this path with run_training.py:" + echo " ${DATASET_DIR}" + exit 0 +fi + +echo "Downloading TensorFlow flower_photos dataset..." +if command -v curl >/dev/null 2>&1; then + curl -L "${DATASET_URL}" -o "${ARCHIVE_PATH}" +elif command -v wget >/dev/null 2>&1; then + wget -O "${ARCHIVE_PATH}" "${DATASET_URL}" +else + echo "Either curl or wget is required." >&2 + exit 1 +fi + +echo "Extracting dataset..." +tar -xzf "${ARCHIVE_PATH}" -C "${DEST_DIR}" + +echo "Preparing SageMaker directory layout..." +mkdir -p "${DATASET_DIR}" +for class_name in "${CLASS_NAMES[@]}"; do + cp -R "${RAW_DATASET_DIR}/${class_name}" "${DATASET_DIR}/${class_name}" +done + +echo "Dataset ready: ${DATASET_DIR}" +find "${DATASET_DIR}" -mindepth 1 -maxdepth 1 -type d -print | sort diff --git a/examples/training/run_training.sh b/examples/training/run_training.sh new file mode 100755 index 0000000..8aa5575 --- /dev/null +++ b/examples/training/run_training.sh @@ -0,0 +1,111 @@ +#!/usr/bin/env bash +set -euo pipefail + +CONFIG_PATH="config.yaml" +DATASET_DIR="examples/training/data/flower_photos_sagemaker" +WAIT=false +SKIP_UPLOAD=false +POLL_SECONDS=60 + +usage() { + cat <&2 + usage >&2 + exit 1 + ;; + esac +done + +if [[ ! -f "${CONFIG_PATH}" ]]; then + echo "Config not found: ${CONFIG_PATH}" >&2 + exit 1 +fi + +if [[ "${SKIP_UPLOAD}" == false && ! -d "${DATASET_DIR}" ]]; then + echo "Dataset not found: ${DATASET_DIR}" >&2 + echo "Run: bash examples/training/download_flower_photos.sh" >&2 + exit 1 +fi + +run() { + echo "+ $*" + "$@" +} + +run uv run qc-cli infra status --config "${CONFIG_PATH}" + +if [[ "${SKIP_UPLOAD}" == false ]]; then + run uv run qc-cli upload "${DATASET_DIR}" --config "${CONFIG_PATH}" +fi + +TRAIN_OUTPUT="$(uv run qc-cli train start --config "${CONFIG_PATH}")" +echo "${TRAIN_OUTPUT}" + +JOB_NAME="$(printf '%s\n' "${TRAIN_OUTPUT}" | grep -Eo 'qc-cli-[0-9]{8}-[0-9]{6}' | tail -n 1)" +if [[ -z "${JOB_NAME}" ]]; then + echo "Could not find training job name in qc-cli output." >&2 + exit 1 +fi + +echo "Submitted SageMaker training job: ${JOB_NAME}" + +if [[ "${WAIT}" == false ]]; then + run uv run qc-cli train status "${JOB_NAME}" --config "${CONFIG_PATH}" + exit 0 +fi + +while true; do + STATUS_OUTPUT="$(uv run qc-cli train status "${JOB_NAME}" --config "${CONFIG_PATH}")" + echo "${STATUS_OUTPUT}" + + if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Completed'; then + echo "Training completed successfully." + exit 0 + fi + + if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Failed'; then + echo "Training failed." >&2 + exit 1 + fi + + if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Stopped'; then + echo "Training stopped." >&2 + exit 1 + fi + + sleep "${POLL_SECONDS}" +done diff --git a/examples/training/source/requirements.txt b/examples/training/source/requirements.txt new file mode 100644 index 0000000..90dcba3 --- /dev/null +++ b/examples/training/source/requirements.txt @@ -0,0 +1 @@ +onnx==1.21.0 diff --git a/examples/training/source/train.py b/examples/training/source/train.py new file mode 100644 index 0000000..51c823e --- /dev/null +++ b/examples/training/source/train.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +"""SageMaker entry point for CPU image-classification training.""" + +from __future__ import annotations + +import argparse +import json +import os +import random +from pathlib import Path + +import torch +from torch import nn +from torch.utils.data import DataLoader, Subset, random_split +from torchvision import datasets, transforms + + +class SmallImageClassifier(nn.Module): + def __init__(self, class_count: int) -> None: + super().__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((1, 1)), + ) + self.classifier = nn.Linear(64, class_count) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.features(x) + x = torch.flatten(x, 1) + return self.classifier(x) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--learning-rate", type=float, default=0.001) + parser.add_argument("--image-size", type=int, default=160) + parser.add_argument("--validation-split", type=float, default=0.2) + parser.add_argument("--max-samples", type=int, default=0) + parser.add_argument("--seed", type=int, default=13) + parser.add_argument("--num-workers", type=int, default=2) + parser.add_argument("--train-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train")) + parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")) + return parser.parse_args() + + +def build_datasets(args: argparse.Namespace) -> tuple[Subset, Subset, dict[str, int]]: + transform = transforms.Compose( + [ + transforms.Resize((args.image_size, args.image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) + dataset = datasets.ImageFolder(args.train_dir, transform=transform) + if len(dataset.classes) < 2: + raise ValueError(f"Expected at least two classes in {args.train_dir}. Found: {dataset.classes}") + + if args.max_samples > 0 and args.max_samples < len(dataset): + indices = list(range(len(dataset))) + random.Random(args.seed).shuffle(indices) + dataset = Subset(dataset, indices[: args.max_samples]) + + validation_size = max(1, int(len(dataset) * args.validation_split)) + train_size = len(dataset) - validation_size + if train_size < 1: + raise ValueError("Not enough images to create a train/validation split.") + + generator = torch.Generator().manual_seed(args.seed) + train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size], generator=generator) + return train_dataset, validation_dataset, getattr(dataset, "dataset", dataset).class_to_idx + + +def run_epoch( + model: nn.Module, + data_loader: DataLoader, + criterion: nn.Module, + optimizer: torch.optim.Optimizer | None, + device: torch.device, +) -> tuple[float, float]: + training = optimizer is not None + model.train(training) + + total_loss = 0.0 + total_correct = 0 + total_examples = 0 + + for images, labels in data_loader: + images = images.to(device) + labels = labels.to(device) + + with torch.set_grad_enabled(training): + logits = model(images) + loss = criterion(logits, labels) + + if training: + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() * images.size(0) + total_correct += (logits.argmax(dim=1) == labels).sum().item() + total_examples += images.size(0) + + return total_loss / total_examples, total_correct / total_examples + + +def export_onnx(model: nn.Module, model_dir: Path, image_size: int) -> None: + model.eval() + dummy_input = torch.randn(1, 3, image_size, image_size) + torch.onnx.export( + model, + dummy_input, + model_dir / "model.onnx", + export_params=True, + opset_version=17, + do_constant_folding=True, + input_names=["input"], + output_names=["logits"], + dynamic_axes={ + "input": {0: "batch_size"}, + "logits": {0: "batch_size"}, + }, + ) + + +def main() -> None: + args = parse_args() + random.seed(args.seed) + torch.manual_seed(args.seed) + + train_dataset, validation_dataset, class_to_idx = build_datasets(args) + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + ) + validation_loader = DataLoader( + validation_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = SmallImageClassifier(class_count=len(class_to_idx)).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + print(f"Training on {device}. Classes: {sorted(class_to_idx)}") + metrics = [] + for epoch in range(1, args.epochs + 1): + train_loss, train_accuracy = run_epoch(model, train_loader, criterion, optimizer, device) + validation_loss, validation_accuracy = run_epoch(model, validation_loader, criterion, None, device) + epoch_metrics = { + "epoch": epoch, + "train_loss": train_loss, + "train_accuracy": train_accuracy, + "validation_loss": validation_loss, + "validation_accuracy": validation_accuracy, + } + metrics.append(epoch_metrics) + print(json.dumps(epoch_metrics, sort_keys=True)) + + model_dir = Path(args.model_dir) + model_dir.mkdir(parents=True, exist_ok=True) + torch.save( + { + "model_state_dict": model.cpu().state_dict(), + "class_to_idx": class_to_idx, + "image_size": args.image_size, + }, + model_dir / "model.pt", + ) + export_onnx(model, model_dir, args.image_size) + (model_dir / "class_to_idx.json").write_text(json.dumps(class_to_idx, indent=2), encoding="utf-8") + (model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8") + print(f"Saved model artifacts to {model_dir}") + + +if __name__ == "__main__": + main() diff --git a/src/aws/iam.py b/src/aws/iam.py new file mode 100644 index 0000000..9422029 --- /dev/null +++ b/src/aws/iam.py @@ -0,0 +1,17 @@ +import boto3 +from botocore.exceptions import ClientError +from mypy_boto3_iam import IAMClient + + +def _client(profile: str) -> IAMClient: + return boto3.Session(profile_name=profile).client("iam") + + +def get_role_arn(profile: str, role_name: str) -> str | None: + client = _client(profile) + try: + return client.get_role(RoleName=role_name)["Role"]["Arn"] + except ClientError as e: + if e.response.get("Error", {}).get("Code") == "NoSuchEntity": + return None + raise diff --git a/src/aws/sagemaker.py b/src/aws/sagemaker.py new file mode 100644 index 0000000..158a8ca --- /dev/null +++ b/src/aws/sagemaker.py @@ -0,0 +1,131 @@ +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +import boto3 +from mypy_boto3_sagemaker import SageMakerClient +from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType +from mypy_boto3_sagemaker.type_defs import ( + CreateTrainingJobRequestTypeDef, + ResourceConfigTypeDef, + TrainingJobSummaryTypeDef, +) + +from src.config import Boto3SessionKwargs + + +@dataclass(frozen=True) +class TrainingJobRequest: + role_arn: str + image_uri: str + instance_type: TrainingInstanceTypeType + instance_count: int + s3_train_uri: str + s3_output_path: str + job_name: str + hyperparameters: dict[str, Any] = field(default_factory=dict) + entry_point: str | None = None + source_dir: str | None = None + + +@dataclass(frozen=True) +class TrainingJobStatus: + name: str + status: str + created: datetime | None + modified: datetime | None + model_artifacts: str | None + failure_reason: str | None + + +def _sm(session: Boto3SessionKwargs) -> SageMakerClient: + return boto3.Session(**session).client("sagemaker") + + +def _upload_source_dir( + session: Boto3SessionKwargs, + source_dir: str, + s3_output_path: str, + job_name: str, +) -> str: + import io + import tarfile + + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + tar.add(source_dir, arcname=".") + buf.seek(0) + + without_scheme = s3_output_path.removeprefix("s3://") + bucket, _, prefix = without_scheme.partition("/") + key = f"{prefix.rstrip('/')}/{job_name}/source/sourcedir.tar.gz".lstrip("/") + + boto3.Session(**session).client("s3").upload_fileobj(buf, bucket, key) + return f"s3://{bucket}/{key}" + + +def start_training_job(session: Boto3SessionKwargs, job: TrainingJobRequest) -> str: + hp = {k: str(v) for k, v in job.hyperparameters.items()} + + if job.source_dir: + s3_code_uri = _upload_source_dir( + session, + job.source_dir, + job.s3_output_path, + job.job_name, + ) + hp["sagemaker_program"] = job.entry_point or "train.py" + hp["sagemaker_submit_directory"] = s3_code_uri + + resource_config: ResourceConfigTypeDef = { + "InstanceType": job.instance_type, + "InstanceCount": job.instance_count, + "VolumeSizeInGB": 30, + } + request: CreateTrainingJobRequestTypeDef = { + "TrainingJobName": job.job_name, + "AlgorithmSpecification": {"TrainingImage": job.image_uri, "TrainingInputMode": "File"}, + "RoleArn": job.role_arn, + "InputDataConfig": [ + { + "ChannelName": "train", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": job.s3_train_uri, + "S3DataDistributionType": "FullyReplicated", + } + }, + } + ], + "OutputDataConfig": {"S3OutputPath": job.s3_output_path}, + "ResourceConfig": resource_config, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "HyperParameters": hp, + } + _sm(session).create_training_job(**request) + return job.job_name + + +def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> TrainingJobStatus: + resp = _sm(session).describe_training_job(TrainingJobName=job_name) + return TrainingJobStatus( + name=resp["TrainingJobName"], + status=resp["TrainingJobStatus"], + created=resp.get("CreationTime"), + modified=resp.get("LastModifiedTime"), + model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"), + failure_reason=resp.get("FailureReason"), + ) + + +def list_training_jobs( + session: Boto3SessionKwargs, + max_results: int = 10, +) -> list[TrainingJobSummaryTypeDef]: + resp = _sm(session).list_training_jobs( + SortBy="CreationTime", + SortOrder="Descending", + MaxResults=max_results, + ) + return list(resp["TrainingJobSummaries"]) diff --git a/src/commands/train.py b/src/commands/train.py new file mode 100644 index 0000000..5de70be --- /dev/null +++ b/src/commands/train.py @@ -0,0 +1,126 @@ +from datetime import datetime + +import typer +from rich.table import Table + +from src import state as state_ops +from src.aws import iam +from src.aws import sagemaker as sm_ops +from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg + +app = typer.Typer(help="Manage SageMaker training jobs") + +_STATUS_COLOR = { + "Completed": "green", + "Failed": "red", + "InProgress": "yellow", + "Stopping": "yellow", + "Stopped": "dim", +} + + +def _config_dir(config_path: str) -> str: + from pathlib import Path + return str(Path(config_path).parent) + + +@app.command() +def start(config: str = CONFIG_OPT) -> None: + """Submit a SageMaker training job.""" + cfg = load_cfg(config) + + if not cfg.sagemaker.training.image_uri: + CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]") + CONSOLE.print( + "Find pre-built images at: " + "https://aws.github.io/deep-learning-containers/reference/available_images" + ) + raise typer.Exit(1) + + role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name) + if not role_arn: + CONSOLE.print(f"[red]IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.[/red]") + raise typer.Exit(1) + + job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}" + s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}" + s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}" + + CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...") + training_job = sm_ops.TrainingJobRequest( + role_arn=role_arn, + image_uri=cfg.sagemaker.training.image_uri, + instance_type=cfg.sagemaker.training.instance_type, + instance_count=cfg.sagemaker.training.instance_count, + s3_train_uri=s3_train_uri, + s3_output_path=s3_output, + job_name=job_name, + hyperparameters=cfg.sagemaker.training.hyperparameters, + entry_point=cfg.sagemaker.training.entry_point, + source_dir=cfg.sagemaker.training.source_dir, + ) + sm_ops.start_training_job(cfg.aws.boto3_session, training_job) + + state_ops.write_state(_config_dir(config), last_training_job=job_name) + + CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]") + CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") + + +@app.command() +def status( + job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"), + config: str = CONFIG_OPT, +) -> None: + """Show training job status.""" + cfg = load_cfg(config) + + if not job_name: + job_name = state_ops.get_last_training_job(_config_dir(config)) + if not job_name: + CONSOLE.print( + "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]" + ) + raise typer.Exit(1) + + status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) + color = _STATUS_COLOR.get(status.status, "white") + + CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]") + CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]") + if status.created: + CONSOLE.print(f"Created: {status.created}") + if status.model_artifacts: + CONSOLE.print(f"Artifacts: {status.model_artifacts}") + if status.failure_reason: + CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]") + + +@app.command(name="list") +def list_jobs( + limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"), + config: str = CONFIG_OPT, +) -> None: + """List recent training jobs.""" + cfg = load_cfg(config) + jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit) + + if not jobs: + CONSOLE.print("[yellow]No training jobs found.[/yellow]") + return + + table = Table(title="Training Jobs") + table.add_column("Name", style="cyan") + table.add_column("Status") + table.add_column("Created") + + for job in jobs: + status_value = str(job["TrainingJobStatus"]) + color = _STATUS_COLOR.get(status_value, "white") + table.add_row( + str(job["TrainingJobName"]), + f"[{color}]{status_value}[/{color}]", + str(job.get("CreationTime", "")), + ) + + CONSOLE.print(table) diff --git a/src/config.py b/src/config.py index de30b3f..833dddd 100644 --- a/src/config.py +++ b/src/config.py @@ -1,7 +1,8 @@ from enum import Enum -from typing import Literal +from typing import Any, Literal, TypedDict from mypy_boto3_s3.literals import BucketLocationConstraintType +from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType from pydantic import BaseModel, Field, model_validator @@ -17,10 +18,19 @@ class MlflowServerSize(str, Enum): large = "Large" +class Boto3SessionKwargs(TypedDict): + profile_name: str + region_name: str + + class AwsConfig(BaseModel): region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1" profile: str = "default" + @property + def boto3_session(self) -> Boto3SessionKwargs: + return {"profile_name": self.profile, "region_name": self.region} + class S3Config(BaseModel): bucket: str = "my-qc-mlops-bucket" @@ -28,8 +38,18 @@ class S3Config(BaseModel): model_prefix: str = "models/" +class TrainingConfig(BaseModel): + instance_type: TrainingInstanceTypeType = "ml.m5.xlarge" + instance_count: int = 1 + image_uri: str = "" + entry_point: str | None = None + source_dir: str | None = None + hyperparameters: dict[str, Any] = Field(default_factory=dict) + + class SageMakerConfig(BaseModel): role_name: str = "qc-cli-sagemaker-role" + training: TrainingConfig = Field(default_factory=TrainingConfig) class MlflowConfig(BaseModel): diff --git a/src/main.py b/src/main.py index 3414744..a2ecf5e 100644 --- a/src/main.py +++ b/src/main.py @@ -6,7 +6,7 @@ from rich.console import Console from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn from src.aws import s3 as s3_ops -from src.commands import infra +from src.commands import infra, train from src.commands.utils import CONFIG_OPT, load_cfg from src.config import Config @@ -15,6 +15,7 @@ app = typer.Typer( no_args_is_help=True, ) app.add_typer(infra.app, name="infra") +app.add_typer(train.app, name="train") console = Console() @@ -36,7 +37,10 @@ def init( yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False) console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]") - console.print("Edit it (especially [cyan]s3.bucket[/cyan]) before running other commands.") + console.print( + "Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) " + "before running other commands." + ) @app.command() diff --git a/src/state.py b/src/state.py new file mode 100644 index 0000000..3b18434 --- /dev/null +++ b/src/state.py @@ -0,0 +1,30 @@ +import json +from pathlib import Path +from typing import Any + +STATE_FILE = ".qc-cli.json" + + +def _path(config_dir: str) -> Path: + return Path(config_dir) / STATE_FILE + + +def read_state(config_dir: str = ".") -> dict[str, Any]: + path = _path(config_dir) + if not path.exists(): + return {} + with open(path) as f: + return json.load(f) + + +def write_state(config_dir: str = ".", **updates: str | None) -> None: + path = _path(config_dir) + state = read_state(config_dir) + state.update(updates) + with open(path, "w") as f: + json.dump(state, f, indent=2) + + +def get_last_training_job(config_dir: str = ".") -> str | None: + value = read_state(config_dir).get("last_training_job") + return str(value) if value else None