#!/usr/bin/env python3
"""SageMaker entry point for image-classification training.

Trains a small CNN on an ``ImageFolder``-style directory (one sub-directory
per class) using CUDA when available and the CPU otherwise, then writes
``model.pt``, ``model.onnx``, ``class_to_idx.json`` and ``metrics.json`` to
the model directory.
"""

from __future__ import annotations

import argparse
import json
import os
import random
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset, random_split


class SmallImageClassifier(nn.Module):
    """Three-stage convolutional classifier for 3-channel images.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2). A global
    average pool then reduces the 64 feature maps to a 64-vector, which a
    single linear layer maps to ``class_count`` logits. Because of the
    adaptive pool, any input height/width of at least 8 pixels (enough to
    survive three 2x poolings) is accepted.
    """

    def __init__(self, class_count: int) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Linear(64, class_count)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(batch, 3, H, W)`` image tensor to ``(batch, class_count)`` logits."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse training hyperparameters.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is how SageMaker passes hyperparameters.

    The data and model directories default to the ``SM_CHANNEL_TRAIN`` and
    ``SM_MODEL_DIR`` environment variables, falling back to the standard
    ``/opt/ml`` paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=0.001)
    parser.add_argument("--image-size", type=int, default=160)
    parser.add_argument("--validation-split", type=float, default=0.2)
    parser.add_argument("--max-samples", type=int, default=0)
    parser.add_argument("--seed", type=int, default=13)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--train-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    return parser.parse_args(argv)


def split_dataset(
    dataset: Dataset,
    validation_split: float,
    max_samples: int = 0,
    seed: int = 13,
) -> tuple[Subset, Subset]:
    """Optionally subsample *dataset*, then split it into train/validation subsets.

    Args:
        dataset: A map-style dataset supporting ``len()`` and indexing.
        validation_split: Fraction of the (possibly subsampled) examples to
            hold out for validation; must be in ``[0, 1)``. At least one
            example is always held out, even when this is 0.
        max_samples: When positive and smaller than ``len(dataset)``, keep
            only this many examples, chosen by a shuffle seeded with *seed*.
        seed: Seed for both the subsampling shuffle and the random split, so
            the same arguments always produce the same subsets.

    Returns:
        ``(train_subset, validation_subset)``.

    Raises:
        ValueError: If *validation_split* is outside ``[0, 1)`` (including
            NaN), or too few examples remain for a non-empty training subset.
    """
    # Written as a positive range check so NaN is rejected as well.
    if not 0.0 <= validation_split < 1.0:
        raise ValueError(f"validation_split must be in [0, 1). Got: {validation_split}")

    if 0 < max_samples < len(dataset):
        indices = list(range(len(dataset)))
        random.Random(seed).shuffle(indices)
        dataset = Subset(dataset, indices[:max_samples])

    validation_size = max(1, int(len(dataset) * validation_split))
    train_size = len(dataset) - validation_size
    if train_size < 1:
        raise ValueError("Not enough images to create a train/validation split.")

    # A dedicated generator keeps the split independent of the global torch RNG state.
    generator = torch.Generator().manual_seed(seed)
    train_subset, validation_subset = random_split(dataset, [train_size, validation_size], generator=generator)
    return train_subset, validation_subset


def build_datasets(args: argparse.Namespace) -> tuple[Subset, Subset, dict[str, int]]:
    """Load ``args.train_dir`` as an ImageFolder and split it for training.

    Images are resized to ``args.image_size`` squared, converted to tensors
    and normalized with the ImageNet channel statistics. Subsampling and
    splitting follow ``args.max_samples``, ``args.validation_split`` and
    ``args.seed`` (see :func:`split_dataset`).

    Returns:
        ``(train_subset, validation_subset, class_to_idx)``, where
        ``class_to_idx`` is the ImageFolder's class-name -> label mapping.

    Raises:
        ValueError: If fewer than two classes are found, or from
            :func:`split_dataset`.
    """
    # Imported here rather than at module level so the model and training
    # utilities in this file can be imported without torchvision installed.
    from torchvision import datasets, transforms

    transform = transforms.Compose(
        [
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )
    image_folder = datasets.ImageFolder(args.train_dir, transform=transform)
    if len(image_folder.classes) < 2:
        raise ValueError(f"Expected at least two classes in {args.train_dir}. Found: {image_folder.classes}")

    train_dataset, validation_dataset = split_dataset(
        image_folder,
        validation_split=args.validation_split,
        max_samples=args.max_samples,
        seed=args.seed,
    )
    return train_dataset, validation_dataset, image_folder.class_to_idx


def run_epoch(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer | None,
    device: torch.device,
) -> tuple[float, float]:
    """Make one pass over *data_loader*: train when *optimizer* is given, else evaluate.

    The model is put in train mode (and gradients enabled) when *optimizer*
    is not ``None``, and in eval mode with gradients disabled otherwise.

    Returns:
        ``(mean loss per example, accuracy)``. Training accuracy is measured
        on the logits computed before each optimizer step.

    Raises:
        ValueError: If *data_loader* yields no examples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_examples = 0

    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        with torch.set_grad_enabled(training):
            logits = model(images)
            loss = criterion(logits, labels)
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        batch_size = images.size(0)
        # Assumes criterion returns the batch-mean loss (nn.CrossEntropyLoss
        # default); re-weighting by batch size keeps the epoch value a true
        # per-example mean when the last batch is smaller.
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_examples += batch_size

    if total_examples == 0:
        raise ValueError("data_loader yielded no examples; cannot compute epoch metrics.")
    return total_loss / total_examples, total_correct / total_examples


def export_onnx(model: nn.Module, model_dir: Path, image_size: int) -> None:
    """Export *model* to ``model_dir / "model.onnx"`` (opset 17).

    The graph takes an ``input`` tensor of shape ``(batch, 3, image_size,
    image_size)`` and returns ``logits``; the batch dimension is dynamic.
    The model is left in eval mode.
    """
    model.eval()
    # Trace with an example input on the model's own device, so export works
    # whether or not the caller moved the model to the CPU first.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
    torch.onnx.export(
        model,
        dummy_input,
        model_dir / "model.onnx",
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={
            "input": {0: "batch_size"},
            "logits": {0: "batch_size"},
        },
    )


def main() -> None:
    """Train and validate for ``--epochs`` epochs, then write all model artifacts.

    Per-epoch metrics are printed as one JSON object per line and collected
    into ``metrics.json``. The trained weights are saved to ``model.pt``
    together with ``class_to_idx`` and ``image_size``, and also exported to
    ``model.onnx``.
    """
    args = parse_args()
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    train_dataset, validation_dataset, class_to_idx = build_datasets(args)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    validation_loader = DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SmallImageClassifier(class_count=len(class_to_idx)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    print(f"Training on {device}. Classes: {sorted(class_to_idx)}")

    metrics: list[dict[str, float | int]] = []
    for epoch in range(1, args.epochs + 1):
        train_loss, train_accuracy = run_epoch(model, train_loader, criterion, optimizer, device)
        validation_loss, validation_accuracy = run_epoch(model, validation_loader, criterion, None, device)
        epoch_metrics = {
            "epoch": epoch,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "validation_loss": validation_loss,
            "validation_accuracy": validation_accuracy,
        }
        metrics.append(epoch_metrics)
        print(json.dumps(epoch_metrics, sort_keys=True))

    model_dir = Path(args.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    # Move to the CPU once so both the saved state dict and the ONNX export
    # are device-independent.
    model = model.cpu()
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "class_to_idx": class_to_idx,
            "image_size": args.image_size,
        },
        model_dir / "model.pt",
    )
    export_onnx(model, model_dir, args.image_size)
    (model_dir / "class_to_idx.json").write_text(json.dumps(class_to_idx, indent=2), encoding="utf-8")
    (model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
    print(f"Saved model artifacts to {model_dir}")


if __name__ == "__main__":
    main()