command to start sagemaker training
include sample training
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -220,4 +220,5 @@ __marimo__/
|
|||||||
.venv/
|
.venv/
|
||||||
config.yaml
|
config.yaml
|
||||||
cdk.out/
|
cdk.out/
|
||||||
.qc-cli-infra*
|
.qc-cli*.json
|
||||||
|
examples/*/data/
|
||||||
|
|||||||
30
README.md
30
README.md
@@ -30,11 +30,16 @@ qc-cli --help
|
|||||||
# 1. Create config.yaml in the current directory
|
# 1. Create config.yaml in the current directory
|
||||||
qc-cli init
|
qc-cli init
|
||||||
|
|
||||||
# 2. Edit config.yaml — at minimum set s3.bucket and sagemaker.role_name
|
# 2. Edit config.yaml — at minimum set s3.bucket and sagemaker.training.image_uri
|
||||||
|
|
||||||
# 3. Provision AWS infrastructure (S3 bucket + SageMaker IAM role).
|
# 3. Provision AWS infrastructure (S3 bucket + SageMaker IAM role).
|
||||||
# This is the step that requires the AWS CDK CLI.
|
# This is the step that requires the AWS CDK CLI.
|
||||||
qc-cli infra setup
|
qc-cli infra setup
|
||||||
|
|
||||||
|
# 4. Upload training data, then submit a SageMaker training job.
|
||||||
|
qc-cli upload ./my-dataset
|
||||||
|
qc-cli train start
|
||||||
|
qc-cli train status
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
@@ -51,8 +56,17 @@ s3:
|
|||||||
|
|
||||||
sagemaker:
|
sagemaker:
|
||||||
role_name: qc-cli-sagemaker-role
|
role_name: qc-cli-sagemaker-role
|
||||||
|
training:
|
||||||
|
image_uri: "" # ECR URI for your training container
|
||||||
|
instance_type: ml.m5.xlarge
|
||||||
|
instance_count: 1
|
||||||
|
entry_point: null # Optional: script inside source_dir
|
||||||
|
source_dir: null # Optional: local dir packaged and uploaded automatically
|
||||||
|
hyperparameters: {}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`hyperparameters` is a flat map of values passed to the training container. Valid keys depend on the selected training image and entry point.
|
||||||
|
|
||||||
To provision an MLflow tracking server, set:
|
To provision an MLflow tracking server, set:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -101,6 +115,19 @@ qc-cli upload <file> --s3-key <key> Upload a file to a custom S3 key
|
|||||||
|
|
||||||
Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads default to `s3://<bucket>/<data_prefix>/<filename>`. Directory uploads are recursive, preserve paths relative to the uploaded directory, and place files under `s3://<bucket>/<data_prefix>/`.
|
Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads default to `s3://<bucket>/<data_prefix>/<filename>`. Directory uploads are recursive, preserve paths relative to the uploaded directory, and place files under `s3://<bucket>/<data_prefix>/`.
|
||||||
|
|
||||||
|
### `train`
|
||||||
|
|
||||||
|
```
|
||||||
|
qc-cli train start Submit a SageMaker training job
|
||||||
|
qc-cli train status [job-name] Show job status; defaults to the last submitted job
|
||||||
|
qc-cli train list List recent training jobs
|
||||||
|
qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||||
|
```
|
||||||
|
|
||||||
|
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
|
||||||
|
|
||||||
|
The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
|
||||||
|
|
||||||
## AWS permissions required
|
## AWS permissions required
|
||||||
|
|
||||||
The IAM user or role running the CLI needs:
|
The IAM user or role running the CLI needs:
|
||||||
@@ -111,6 +138,7 @@ The IAM user or role running the CLI needs:
|
|||||||
| CreateRole, GetRole, DeleteRole, AttachRolePolicy, DetachRolePolicy | IAM |
|
| CreateRole, GetRole, DeleteRole, AttachRolePolicy, DetachRolePolicy | IAM |
|
||||||
| CreateStack, UpdateStack, DeleteStack, DescribeStacks, DescribeStackEvents | CloudFormation |
|
| CreateStack, UpdateStack, DeleteStack, DescribeStacks, DescribeStackEvents | CloudFormation |
|
||||||
| GetCallerIdentity | STS |
|
| GetCallerIdentity | STS |
|
||||||
|
| CreateTrainingJob, DescribeTrainingJob, ListTrainingJobs | SageMaker AI |
|
||||||
| CreateMlflowTrackingServer, DescribeMlflowTrackingServer, DeleteMlflowTrackingServer | SageMaker AI, when `mlflow.mode` is `create` or `existing` |
|
| CreateMlflowTrackingServer, DescribeMlflowTrackingServer, DeleteMlflowTrackingServer | SageMaker AI, when `mlflow.mode` is `create` or `existing` |
|
||||||
|
|
||||||
`AdministratorAccess` covers all of the above.
|
`AdministratorAccess` covers all of the above.
|
||||||
|
|||||||
90
examples/training/README.md
Normal file
90
examples/training/README.md
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# SageMaker Training Example
|
||||||
|
|
||||||
|
This example downloads a small image-classification dataset, uploads it through `qc-cli`, and submits a live SageMaker training job.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- AWS credentials configured for the profile in `config.yaml`
|
||||||
|
- Infrastructure already deployed with `qc-cli infra setup`
|
||||||
|
- `config.yaml` updated with:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
s3:
|
||||||
|
bucket: your-bucket-name
|
||||||
|
|
||||||
|
sagemaker:
|
||||||
|
role_name: <role-name>
|
||||||
|
training:
|
||||||
|
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
||||||
|
instance_type: ml.m4.xlarge
|
||||||
|
instance_count: 1
|
||||||
|
source_dir: examples/training/source
|
||||||
|
entry_point: train.py
|
||||||
|
hyperparameters:
|
||||||
|
epochs: 1
|
||||||
|
batch-size: 32
|
||||||
|
learning-rate: 0.001
|
||||||
|
image-size: 160
|
||||||
|
validation-split: 0.2
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training Hyperparameters
|
||||||
|
|
||||||
|
Values under `sagemaker.training.hyperparameters` are passed to the training entry point as command-line arguments. For this example, they map to arguments defined in [source/train.py](source/train.py).
|
||||||
|
|
||||||
|
Supported by this example:
|
||||||
|
|
||||||
|
| Name | Type | Default | Description |
|
||||||
|
|---|---:|---:|---|
|
||||||
|
| `epochs` | int | `1` | Number of training epochs. |
|
||||||
|
| `batch-size` | int | `32` | Images per training batch. |
|
||||||
|
| `learning-rate` | float | `0.001` | Adam optimizer learning rate. |
|
||||||
|
| `image-size` | int | `160` | Resize images to square `image-size x image-size`. |
|
||||||
|
| `validation-split` | float | `0.2` | Fraction of data used for validation. |
|
||||||
|
| `max-samples` | int | `0` | Optional cap for smoke tests; `0` means use all images. |
|
||||||
|
| `seed` | int | `13` | Random seed for reproducible splitting. |
|
||||||
|
| `num-workers` | int | `2` | DataLoader worker count. |
|
||||||
|
|
||||||
|
Do not set `train-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`.
|
||||||
|
|
||||||
|
## 1. Download The Dataset
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/training/download_flower_photos.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
This creates:
|
||||||
|
|
||||||
|
```text
|
||||||
|
examples/training/data/flower_photos_sagemaker/
|
||||||
|
daisy/
|
||||||
|
dandelion/
|
||||||
|
roses/
|
||||||
|
sunflowers/
|
||||||
|
tulips/
|
||||||
|
```
|
||||||
|
|
||||||
|
## 2. Run Training
|
||||||
|
|
||||||
|
Run the training script and wait until it finishes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/training/run_training.sh --config config.yaml --wait
|
||||||
|
```
|
||||||
|
|
||||||
|
Use a dataset that is already uploaded to `s3.data_prefix`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/training/run_training.sh \
|
||||||
|
--config config.yaml \
|
||||||
|
--skip-upload \
|
||||||
|
--wait
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- The default dataset path is `examples/training/data/flower_photos_sagemaker`.
|
||||||
|
- Uploaded data uses the `s3.bucket` and `s3.data_prefix` values from `config.yaml`.
|
||||||
|
- Training artifacts are written under `s3://<bucket>/<model_prefix>/`.
|
||||||
|
- The SageMaker `model.tar.gz` contains `model.onnx`, `model.pt`, `class_to_idx.json`, and `metrics.json`.
|
||||||
|
- SageMaker packages `examples/training/source`, installs `requirements.txt`, and runs `train.py`.
|
||||||
40
examples/training/download_flower_photos.sh
Executable file
40
examples/training/download_flower_photos.sh
Executable file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
DATASET_URL="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
|
||||||
|
DEST_DIR="${1:-examples/training/data}"
|
||||||
|
ARCHIVE_PATH="${DEST_DIR}/flower_photos.tgz"
|
||||||
|
RAW_DATASET_DIR="${DEST_DIR}/flower_photos"
|
||||||
|
DATASET_DIR="${DEST_DIR}/flower_photos_sagemaker"
|
||||||
|
CLASS_NAMES=("daisy" "dandelion" "roses" "sunflowers" "tulips")
|
||||||
|
|
||||||
|
mkdir -p "${DEST_DIR}"
|
||||||
|
|
||||||
|
if [[ -d "${DATASET_DIR}" ]]; then
|
||||||
|
echo "Dataset already exists: ${DATASET_DIR}"
|
||||||
|
echo "Use this path with run_training.py:"
|
||||||
|
echo " ${DATASET_DIR}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Downloading TensorFlow flower_photos dataset..."
|
||||||
|
if command -v curl >/dev/null 2>&1; then
|
||||||
|
curl -L "${DATASET_URL}" -o "${ARCHIVE_PATH}"
|
||||||
|
elif command -v wget >/dev/null 2>&1; then
|
||||||
|
wget -O "${ARCHIVE_PATH}" "${DATASET_URL}"
|
||||||
|
else
|
||||||
|
echo "Either curl or wget is required." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Extracting dataset..."
|
||||||
|
tar -xzf "${ARCHIVE_PATH}" -C "${DEST_DIR}"
|
||||||
|
|
||||||
|
echo "Preparing SageMaker directory layout..."
|
||||||
|
mkdir -p "${DATASET_DIR}"
|
||||||
|
for class_name in "${CLASS_NAMES[@]}"; do
|
||||||
|
cp -R "${RAW_DATASET_DIR}/${class_name}" "${DATASET_DIR}/${class_name}"
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Dataset ready: ${DATASET_DIR}"
|
||||||
|
find "${DATASET_DIR}" -mindepth 1 -maxdepth 1 -type d -print | sort
|
||||||
111
examples/training/run_training.sh
Executable file
111
examples/training/run_training.sh
Executable file
@@ -0,0 +1,111 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
CONFIG_PATH="config.yaml"
|
||||||
|
DATASET_DIR="examples/training/data/flower_photos_sagemaker"
|
||||||
|
WAIT=false
|
||||||
|
SKIP_UPLOAD=false
|
||||||
|
POLL_SECONDS=60
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
cat <<EOF
|
||||||
|
Usage: $0 [options]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--config PATH Path to qc-cli config file. Default: config.yaml
|
||||||
|
--dataset-dir PATH Dataset directory to upload. Default: ${DATASET_DIR}
|
||||||
|
--skip-upload Train against data already uploaded to s3.data_prefix.
|
||||||
|
--wait Poll until training completes.
|
||||||
|
-h, --help Show this help.
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case "$1" in
|
||||||
|
--config)
|
||||||
|
CONFIG_PATH="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--dataset-dir)
|
||||||
|
DATASET_DIR="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--skip-upload)
|
||||||
|
SKIP_UPLOAD=true
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
--wait)
|
||||||
|
WAIT=true
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
usage
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1" >&2
|
||||||
|
usage >&2
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ ! -f "${CONFIG_PATH}" ]]; then
|
||||||
|
echo "Config not found: ${CONFIG_PATH}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "${SKIP_UPLOAD}" == false && ! -d "${DATASET_DIR}" ]]; then
|
||||||
|
echo "Dataset not found: ${DATASET_DIR}" >&2
|
||||||
|
echo "Run: bash examples/training/download_flower_photos.sh" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
run() {
|
||||||
|
echo "+ $*"
|
||||||
|
"$@"
|
||||||
|
}
|
||||||
|
|
||||||
|
run uv run qc-cli infra status --config "${CONFIG_PATH}"
|
||||||
|
|
||||||
|
if [[ "${SKIP_UPLOAD}" == false ]]; then
|
||||||
|
run uv run qc-cli upload "${DATASET_DIR}" --config "${CONFIG_PATH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
TRAIN_OUTPUT="$(uv run qc-cli train start --config "${CONFIG_PATH}")"
|
||||||
|
echo "${TRAIN_OUTPUT}"
|
||||||
|
|
||||||
|
JOB_NAME="$(printf '%s\n' "${TRAIN_OUTPUT}" | grep -Eo 'qc-cli-[0-9]{8}-[0-9]{6}' | tail -n 1)"
|
||||||
|
if [[ -z "${JOB_NAME}" ]]; then
|
||||||
|
echo "Could not find training job name in qc-cli output." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Submitted SageMaker training job: ${JOB_NAME}"
|
||||||
|
|
||||||
|
if [[ "${WAIT}" == false ]]; then
|
||||||
|
run uv run qc-cli train status "${JOB_NAME}" --config "${CONFIG_PATH}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
STATUS_OUTPUT="$(uv run qc-cli train status "${JOB_NAME}" --config "${CONFIG_PATH}")"
|
||||||
|
echo "${STATUS_OUTPUT}"
|
||||||
|
|
||||||
|
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Completed'; then
|
||||||
|
echo "Training completed successfully."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Failed'; then
|
||||||
|
echo "Training failed." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Stopped'; then
|
||||||
|
echo "Training stopped." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
sleep "${POLL_SECONDS}"
|
||||||
|
done
|
||||||
1
examples/training/source/requirements.txt
Normal file
1
examples/training/source/requirements.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
onnx==1.21.0
|
||||||
192
examples/training/source/train.py
Normal file
192
examples/training/source/train.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""SageMaker entry point for CPU image-classification training."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader, Subset, random_split
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
|
||||||
|
|
||||||
|
class SmallImageClassifier(nn.Module):
|
||||||
|
def __init__(self, class_count: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.AdaptiveAvgPool2d((1, 1)),
|
||||||
|
)
|
||||||
|
self.classifier = nn.Linear(64, class_count)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.features(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
return self.classifier(x)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--epochs", type=int, default=1)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=32)
|
||||||
|
parser.add_argument("--learning-rate", type=float, default=0.001)
|
||||||
|
parser.add_argument("--image-size", type=int, default=160)
|
||||||
|
parser.add_argument("--validation-split", type=float, default=0.2)
|
||||||
|
parser.add_argument("--max-samples", type=int, default=0)
|
||||||
|
parser.add_argument("--seed", type=int, default=13)
|
||||||
|
parser.add_argument("--num-workers", type=int, default=2)
|
||||||
|
parser.add_argument("--train-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
|
||||||
|
parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def build_datasets(args: argparse.Namespace) -> tuple[Subset, Subset, dict[str, int]]:
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize((args.image_size, args.image_size)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
dataset = datasets.ImageFolder(args.train_dir, transform=transform)
|
||||||
|
if len(dataset.classes) < 2:
|
||||||
|
raise ValueError(f"Expected at least two classes in {args.train_dir}. Found: {dataset.classes}")
|
||||||
|
|
||||||
|
if args.max_samples > 0 and args.max_samples < len(dataset):
|
||||||
|
indices = list(range(len(dataset)))
|
||||||
|
random.Random(args.seed).shuffle(indices)
|
||||||
|
dataset = Subset(dataset, indices[: args.max_samples])
|
||||||
|
|
||||||
|
validation_size = max(1, int(len(dataset) * args.validation_split))
|
||||||
|
train_size = len(dataset) - validation_size
|
||||||
|
if train_size < 1:
|
||||||
|
raise ValueError("Not enough images to create a train/validation split.")
|
||||||
|
|
||||||
|
generator = torch.Generator().manual_seed(args.seed)
|
||||||
|
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size], generator=generator)
|
||||||
|
return train_dataset, validation_dataset, getattr(dataset, "dataset", dataset).class_to_idx
|
||||||
|
|
||||||
|
|
||||||
|
def run_epoch(
|
||||||
|
model: nn.Module,
|
||||||
|
data_loader: DataLoader,
|
||||||
|
criterion: nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer | None,
|
||||||
|
device: torch.device,
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
training = optimizer is not None
|
||||||
|
model.train(training)
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
total_correct = 0
|
||||||
|
total_examples = 0
|
||||||
|
|
||||||
|
for images, labels in data_loader:
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(training):
|
||||||
|
logits = model(images)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
|
||||||
|
if training:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item() * images.size(0)
|
||||||
|
total_correct += (logits.argmax(dim=1) == labels).sum().item()
|
||||||
|
total_examples += images.size(0)
|
||||||
|
|
||||||
|
return total_loss / total_examples, total_correct / total_examples
|
||||||
|
|
||||||
|
|
||||||
|
def export_onnx(model: nn.Module, model_dir: Path, image_size: int) -> None:
|
||||||
|
model.eval()
|
||||||
|
dummy_input = torch.randn(1, 3, image_size, image_size)
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
dummy_input,
|
||||||
|
model_dir / "model.onnx",
|
||||||
|
export_params=True,
|
||||||
|
opset_version=17,
|
||||||
|
do_constant_folding=True,
|
||||||
|
input_names=["input"],
|
||||||
|
output_names=["logits"],
|
||||||
|
dynamic_axes={
|
||||||
|
"input": {0: "batch_size"},
|
||||||
|
"logits": {0: "batch_size"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
|
||||||
|
train_dataset, validation_dataset, class_to_idx = build_datasets(args)
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
)
|
||||||
|
validation_loader = DataLoader(
|
||||||
|
validation_dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model = SmallImageClassifier(class_count=len(class_to_idx)).to(device)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
|
||||||
|
|
||||||
|
print(f"Training on {device}. Classes: {sorted(class_to_idx)}")
|
||||||
|
metrics = []
|
||||||
|
for epoch in range(1, args.epochs + 1):
|
||||||
|
train_loss, train_accuracy = run_epoch(model, train_loader, criterion, optimizer, device)
|
||||||
|
validation_loss, validation_accuracy = run_epoch(model, validation_loader, criterion, None, device)
|
||||||
|
epoch_metrics = {
|
||||||
|
"epoch": epoch,
|
||||||
|
"train_loss": train_loss,
|
||||||
|
"train_accuracy": train_accuracy,
|
||||||
|
"validation_loss": validation_loss,
|
||||||
|
"validation_accuracy": validation_accuracy,
|
||||||
|
}
|
||||||
|
metrics.append(epoch_metrics)
|
||||||
|
print(json.dumps(epoch_metrics, sort_keys=True))
|
||||||
|
|
||||||
|
model_dir = Path(args.model_dir)
|
||||||
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(
|
||||||
|
{
|
||||||
|
"model_state_dict": model.cpu().state_dict(),
|
||||||
|
"class_to_idx": class_to_idx,
|
||||||
|
"image_size": args.image_size,
|
||||||
|
},
|
||||||
|
model_dir / "model.pt",
|
||||||
|
)
|
||||||
|
export_onnx(model, model_dir, args.image_size)
|
||||||
|
(model_dir / "class_to_idx.json").write_text(json.dumps(class_to_idx, indent=2), encoding="utf-8")
|
||||||
|
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
||||||
|
print(f"Saved model artifacts to {model_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
17
src/aws/iam.py
Normal file
17
src/aws/iam.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from mypy_boto3_iam import IAMClient
|
||||||
|
|
||||||
|
|
||||||
|
def _client(profile: str) -> IAMClient:
|
||||||
|
return boto3.Session(profile_name=profile).client("iam")
|
||||||
|
|
||||||
|
|
||||||
|
def get_role_arn(profile: str, role_name: str) -> str | None:
|
||||||
|
client = _client(profile)
|
||||||
|
try:
|
||||||
|
return client.get_role(RoleName=role_name)["Role"]["Arn"]
|
||||||
|
except ClientError as e:
|
||||||
|
if e.response.get("Error", {}).get("Code") == "NoSuchEntity":
|
||||||
|
return None
|
||||||
|
raise
|
||||||
131
src/aws/sagemaker.py
Normal file
131
src/aws/sagemaker.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from mypy_boto3_sagemaker import SageMakerClient
|
||||||
|
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||||
|
from mypy_boto3_sagemaker.type_defs import (
|
||||||
|
CreateTrainingJobRequestTypeDef,
|
||||||
|
ResourceConfigTypeDef,
|
||||||
|
TrainingJobSummaryTypeDef,
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.config import Boto3SessionKwargs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TrainingJobRequest:
|
||||||
|
role_arn: str
|
||||||
|
image_uri: str
|
||||||
|
instance_type: TrainingInstanceTypeType
|
||||||
|
instance_count: int
|
||||||
|
s3_train_uri: str
|
||||||
|
s3_output_path: str
|
||||||
|
job_name: str
|
||||||
|
hyperparameters: dict[str, Any] = field(default_factory=dict)
|
||||||
|
entry_point: str | None = None
|
||||||
|
source_dir: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TrainingJobStatus:
|
||||||
|
name: str
|
||||||
|
status: str
|
||||||
|
created: datetime | None
|
||||||
|
modified: datetime | None
|
||||||
|
model_artifacts: str | None
|
||||||
|
failure_reason: str | None
|
||||||
|
|
||||||
|
|
||||||
|
def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
|
||||||
|
return boto3.Session(**session).client("sagemaker")
|
||||||
|
|
||||||
|
|
||||||
|
def _upload_source_dir(
|
||||||
|
session: Boto3SessionKwargs,
|
||||||
|
source_dir: str,
|
||||||
|
s3_output_path: str,
|
||||||
|
job_name: str,
|
||||||
|
) -> str:
|
||||||
|
import io
|
||||||
|
import tarfile
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||||
|
tar.add(source_dir, arcname=".")
|
||||||
|
buf.seek(0)
|
||||||
|
|
||||||
|
without_scheme = s3_output_path.removeprefix("s3://")
|
||||||
|
bucket, _, prefix = without_scheme.partition("/")
|
||||||
|
key = f"{prefix.rstrip('/')}/{job_name}/source/sourcedir.tar.gz".lstrip("/")
|
||||||
|
|
||||||
|
boto3.Session(**session).client("s3").upload_fileobj(buf, bucket, key)
|
||||||
|
return f"s3://{bucket}/{key}"
|
||||||
|
|
||||||
|
|
||||||
|
def start_training_job(session: Boto3SessionKwargs, job: TrainingJobRequest) -> str:
|
||||||
|
hp = {k: str(v) for k, v in job.hyperparameters.items()}
|
||||||
|
|
||||||
|
if job.source_dir:
|
||||||
|
s3_code_uri = _upload_source_dir(
|
||||||
|
session,
|
||||||
|
job.source_dir,
|
||||||
|
job.s3_output_path,
|
||||||
|
job.job_name,
|
||||||
|
)
|
||||||
|
hp["sagemaker_program"] = job.entry_point or "train.py"
|
||||||
|
hp["sagemaker_submit_directory"] = s3_code_uri
|
||||||
|
|
||||||
|
resource_config: ResourceConfigTypeDef = {
|
||||||
|
"InstanceType": job.instance_type,
|
||||||
|
"InstanceCount": job.instance_count,
|
||||||
|
"VolumeSizeInGB": 30,
|
||||||
|
}
|
||||||
|
request: CreateTrainingJobRequestTypeDef = {
|
||||||
|
"TrainingJobName": job.job_name,
|
||||||
|
"AlgorithmSpecification": {"TrainingImage": job.image_uri, "TrainingInputMode": "File"},
|
||||||
|
"RoleArn": job.role_arn,
|
||||||
|
"InputDataConfig": [
|
||||||
|
{
|
||||||
|
"ChannelName": "train",
|
||||||
|
"DataSource": {
|
||||||
|
"S3DataSource": {
|
||||||
|
"S3DataType": "S3Prefix",
|
||||||
|
"S3Uri": job.s3_train_uri,
|
||||||
|
"S3DataDistributionType": "FullyReplicated",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"OutputDataConfig": {"S3OutputPath": job.s3_output_path},
|
||||||
|
"ResourceConfig": resource_config,
|
||||||
|
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
|
||||||
|
"HyperParameters": hp,
|
||||||
|
}
|
||||||
|
_sm(session).create_training_job(**request)
|
||||||
|
return job.job_name
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> TrainingJobStatus:
|
||||||
|
resp = _sm(session).describe_training_job(TrainingJobName=job_name)
|
||||||
|
return TrainingJobStatus(
|
||||||
|
name=resp["TrainingJobName"],
|
||||||
|
status=resp["TrainingJobStatus"],
|
||||||
|
created=resp.get("CreationTime"),
|
||||||
|
modified=resp.get("LastModifiedTime"),
|
||||||
|
model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
|
||||||
|
failure_reason=resp.get("FailureReason"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_training_jobs(
|
||||||
|
session: Boto3SessionKwargs,
|
||||||
|
max_results: int = 10,
|
||||||
|
) -> list[TrainingJobSummaryTypeDef]:
|
||||||
|
resp = _sm(session).list_training_jobs(
|
||||||
|
SortBy="CreationTime",
|
||||||
|
SortOrder="Descending",
|
||||||
|
MaxResults=max_results,
|
||||||
|
)
|
||||||
|
return list(resp["TrainingJobSummaries"])
|
||||||
126
src/commands/train.py
Normal file
126
src/commands/train.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from src import state as state_ops
|
||||||
|
from src.aws import iam
|
||||||
|
from src.aws import sagemaker as sm_ops
|
||||||
|
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||||
|
|
||||||
|
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||||
|
|
||||||
|
_STATUS_COLOR = {
|
||||||
|
"Completed": "green",
|
||||||
|
"Failed": "red",
|
||||||
|
"InProgress": "yellow",
|
||||||
|
"Stopping": "yellow",
|
||||||
|
"Stopped": "dim",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _config_dir(config_path: str) -> str:
|
||||||
|
from pathlib import Path
|
||||||
|
return str(Path(config_path).parent)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def start(config: str = CONFIG_OPT) -> None:
|
||||||
|
"""Submit a SageMaker training job."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
if not cfg.sagemaker.training.image_uri:
|
||||||
|
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
|
||||||
|
CONSOLE.print(
|
||||||
|
"Find pre-built images at: "
|
||||||
|
"https://aws.github.io/deep-learning-containers/reference/available_images"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||||
|
if not role_arn:
|
||||||
|
CONSOLE.print(f"[red]IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
||||||
|
s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
|
||||||
|
s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"
|
||||||
|
|
||||||
|
CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...")
|
||||||
|
training_job = sm_ops.TrainingJobRequest(
|
||||||
|
role_arn=role_arn,
|
||||||
|
image_uri=cfg.sagemaker.training.image_uri,
|
||||||
|
instance_type=cfg.sagemaker.training.instance_type,
|
||||||
|
instance_count=cfg.sagemaker.training.instance_count,
|
||||||
|
s3_train_uri=s3_train_uri,
|
||||||
|
s3_output_path=s3_output,
|
||||||
|
job_name=job_name,
|
||||||
|
hyperparameters=cfg.sagemaker.training.hyperparameters,
|
||||||
|
entry_point=cfg.sagemaker.training.entry_point,
|
||||||
|
source_dir=cfg.sagemaker.training.source_dir,
|
||||||
|
)
|
||||||
|
sm_ops.start_training_job(cfg.aws.boto3_session, training_job)
|
||||||
|
|
||||||
|
state_ops.write_state(_config_dir(config), last_training_job=job_name)
|
||||||
|
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
||||||
|
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def status(
|
||||||
|
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Show training job status."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
if not job_name:
|
||||||
|
job_name = state_ops.get_last_training_job(_config_dir(config))
|
||||||
|
if not job_name:
|
||||||
|
CONSOLE.print(
|
||||||
|
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||||
|
color = _STATUS_COLOR.get(status.status, "white")
|
||||||
|
|
||||||
|
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||||
|
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||||
|
if status.created:
|
||||||
|
CONSOLE.print(f"Created: {status.created}")
|
||||||
|
if status.model_artifacts:
|
||||||
|
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||||
|
if status.failure_reason:
|
||||||
|
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(name="list")
|
||||||
|
def list_jobs(
|
||||||
|
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""List recent training jobs."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit)
|
||||||
|
|
||||||
|
if not jobs:
|
||||||
|
CONSOLE.print("[yellow]No training jobs found.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Training Jobs")
|
||||||
|
table.add_column("Name", style="cyan")
|
||||||
|
table.add_column("Status")
|
||||||
|
table.add_column("Created")
|
||||||
|
|
||||||
|
for job in jobs:
|
||||||
|
status_value = str(job["TrainingJobStatus"])
|
||||||
|
color = _STATUS_COLOR.get(status_value, "white")
|
||||||
|
table.add_row(
|
||||||
|
str(job["TrainingJobName"]),
|
||||||
|
f"[{color}]{status_value}[/{color}]",
|
||||||
|
str(job.get("CreationTime", "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
CONSOLE.print(table)
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal
|
from typing import Any, Literal, TypedDict
|
||||||
|
|
||||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||||
|
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
@@ -17,10 +18,19 @@ class MlflowServerSize(str, Enum):
|
|||||||
large = "Large"
|
large = "Large"
|
||||||
|
|
||||||
|
|
||||||
|
class Boto3SessionKwargs(TypedDict):
|
||||||
|
profile_name: str
|
||||||
|
region_name: str
|
||||||
|
|
||||||
|
|
||||||
class AwsConfig(BaseModel):
|
class AwsConfig(BaseModel):
|
||||||
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
|
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
|
||||||
profile: str = "default"
|
profile: str = "default"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def boto3_session(self) -> Boto3SessionKwargs:
|
||||||
|
return {"profile_name": self.profile, "region_name": self.region}
|
||||||
|
|
||||||
|
|
||||||
class S3Config(BaseModel):
|
class S3Config(BaseModel):
|
||||||
bucket: str = "my-qc-mlops-bucket"
|
bucket: str = "my-qc-mlops-bucket"
|
||||||
@@ -28,8 +38,18 @@ class S3Config(BaseModel):
|
|||||||
model_prefix: str = "models/"
|
model_prefix: str = "models/"
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingConfig(BaseModel):
|
||||||
|
instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
|
||||||
|
instance_count: int = 1
|
||||||
|
image_uri: str = ""
|
||||||
|
entry_point: str | None = None
|
||||||
|
source_dir: str | None = None
|
||||||
|
hyperparameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class SageMakerConfig(BaseModel):
|
class SageMakerConfig(BaseModel):
|
||||||
role_name: str = "qc-cli-sagemaker-role"
|
role_name: str = "qc-cli-sagemaker-role"
|
||||||
|
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||||
|
|
||||||
|
|
||||||
class MlflowConfig(BaseModel):
|
class MlflowConfig(BaseModel):
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from rich.console import Console
|
|||||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
||||||
|
|
||||||
from src.aws import s3 as s3_ops
|
from src.aws import s3 as s3_ops
|
||||||
from src.commands import infra
|
from src.commands import infra, train
|
||||||
from src.commands.utils import CONFIG_OPT, load_cfg
|
from src.commands.utils import CONFIG_OPT, load_cfg
|
||||||
from src.config import Config
|
from src.config import Config
|
||||||
|
|
||||||
@@ -15,6 +15,7 @@ app = typer.Typer(
|
|||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
app.add_typer(infra.app, name="infra")
|
app.add_typer(infra.app, name="infra")
|
||||||
|
app.add_typer(train.app, name="train")
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
@@ -36,7 +37,10 @@ def init(
|
|||||||
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
|
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
|
||||||
|
|
||||||
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||||
console.print("Edit it (especially [cyan]s3.bucket[/cyan]) before running other commands.")
|
console.print(
|
||||||
|
"Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
|
||||||
|
"before running other commands."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
|
|||||||
30
src/state.py
Normal file
30
src/state.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
STATE_FILE = ".qc-cli.json"
|
||||||
|
|
||||||
|
|
||||||
|
def _path(config_dir: str) -> Path:
|
||||||
|
return Path(config_dir) / STATE_FILE
|
||||||
|
|
||||||
|
|
||||||
|
def read_state(config_dir: str = ".") -> dict[str, Any]:
|
||||||
|
path = _path(config_dir)
|
||||||
|
if not path.exists():
|
||||||
|
return {}
|
||||||
|
with open(path) as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def write_state(config_dir: str = ".", **updates: str | None) -> None:
|
||||||
|
path = _path(config_dir)
|
||||||
|
state = read_state(config_dir)
|
||||||
|
state.update(updates)
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(state, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_training_job(config_dir: str = ".") -> str | None:
|
||||||
|
value = read_state(config_dir).get("last_training_job")
|
||||||
|
return str(value) if value else None
|
||||||
Reference in New Issue
Block a user