189 lines
6.5 KiB
Python
189 lines
6.5 KiB
Python
#!/usr/bin/env python3
|
|
"""SageMaker entry point for image-classification training (targets CPU; uses CUDA when available)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import random
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader, Subset, random_split
|
|
from torchvision import datasets, transforms
|
|
|
|
|
|
class SmallImageClassifier(nn.Module):
    """Compact three-stage CNN that maps a 3-channel image batch to class logits.

    Each stage is a 3x3 convolution (padding 1), a ReLU and a 2x2 max pool;
    channel widths grow 3 -> 16 -> 32 -> 64. A global average pool then
    collapses the spatial dimensions, so the linear head always sees a
    64-dimensional vector for any input large enough to survive three poolings.

    The layer order inside ``features`` is significant: it fixes the
    ``state_dict`` keys (``features.0``, ``features.3``, ``features.6``,
    ``classifier``) that saved checkpoints are loaded against.
    """

    def __init__(self, class_count: int) -> None:
        super().__init__()
        layers: list[nn.Module] = []
        in_channels = 3
        for out_channels in (16, 32, 64):
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
            in_channels = out_channels
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*layers)
        # in_channels is now the width of the last conv stage (64).
        self.classifier = nn.Linear(in_channels, class_count)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised logits of shape ``(batch, class_count)``."""
        pooled = self.features(x)
        return self.classifier(pooled.flatten(start_dim=1))
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse training hyperparameters and input/output locations from ``sys.argv``.

    ``--train-dir`` and ``--model-dir`` default to the ``SM_CHANNEL_TRAIN`` and
    ``SM_MODEL_DIR`` environment variables, falling back to fixed ``/opt/ml``
    paths when those variables are unset.
    """
    parser = argparse.ArgumentParser()

    # (flag, type, default) for every numeric hyperparameter, in declaration order.
    numeric_options = (
        ("--epochs", int, 1),
        ("--batch-size", int, 32),
        ("--learning-rate", float, 0.001),
        ("--image-size", int, 160),
        ("--validation-split", float, 0.2),
        ("--max-samples", int, 0),  # 0 (or less) means "use every image"
        ("--seed", int, 13),
        ("--num-workers", int, 2),
    )
    for flag, kind, default in numeric_options:
        parser.add_argument(flag, type=kind, default=default)

    # Directory options stay plain strings; their defaults come from the environment.
    directory_options = (
        ("--train-dir", "SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
        ("--model-dir", "SM_MODEL_DIR", "/opt/ml/model"),
    )
    for flag, env_var, fallback in directory_options:
        parser.add_argument(flag, default=os.environ.get(env_var, fallback))

    return parser.parse_args()
|
|
|
|
|
|
def build_datasets(args: argparse.Namespace) -> tuple[Subset, Subset, dict[str, int]]:
    """Load an ImageFolder dataset and split it into train/validation subsets.

    Args:
        args: Parsed CLI namespace. Reads ``train_dir``, ``image_size``,
            ``validation_split``, ``max_samples`` and ``seed``.

    Returns:
        ``(train_subset, validation_subset, class_to_idx)``, where
        ``class_to_idx`` is the folder-name -> label mapping built by ImageFolder.

    Raises:
        ValueError: if ``validation_split`` is outside ``[0, 1)``, fewer than two
            class folders are found, or too few images remain to form a split.
    """
    # Validate before touching the filesystem so bad hyperparameters fail fast.
    # Without this, a negative split silently fell back to one validation image
    # and a split >= 1 surfaced as a misleading "Not enough images" error.
    if not 0.0 <= args.validation_split < 1.0:
        raise ValueError(f"--validation-split must be in [0, 1), got {args.validation_split}")

    transform = transforms.Compose(
        [
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
            # These values match the usual ImageNet per-channel mean/std.
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )

    full_dataset = datasets.ImageFolder(args.train_dir, transform=transform)
    if len(full_dataset.classes) < 2:
        raise ValueError(f"Expected at least two classes in {args.train_dir}. Found: {full_dataset.classes}")
    # Keep the label map from the full dataset, so sub-sampling cannot change it.
    class_to_idx = full_dataset.class_to_idx

    dataset = full_dataset
    if 0 < args.max_samples < len(full_dataset):
        # Seeded shuffle so the same subset is drawn on every run with this seed.
        indices = list(range(len(full_dataset)))
        random.Random(args.seed).shuffle(indices)
        dataset = Subset(full_dataset, indices[: args.max_samples])

    # Always hold out at least one image so the validation pass is never empty.
    validation_size = max(1, int(len(dataset) * args.validation_split))
    train_size = len(dataset) - validation_size
    if train_size < 1:
        raise ValueError("Not enough images to create a train/validation split.")

    generator = torch.Generator().manual_seed(args.seed)
    train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size], generator=generator)
    return train_dataset, validation_dataset, class_to_idx
|
|
|
|
|
|
def run_epoch(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer | None,
    device: torch.device,
) -> tuple[float, float]:
    """Run one pass over ``data_loader`` and return ``(mean_loss, accuracy)``.

    The mode is selected by ``optimizer``. When an optimizer is given, the
    model is put in train mode and updated after every batch. When it is
    ``None``, the model is put in eval mode and gradients are disabled.

    Args:
        model: Network producing logits with classes along dimension 1.
        data_loader: Yields ``(inputs, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        optimizer: Optimizer to step, or ``None`` for evaluation.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The example-weighted mean loss and the fraction of examples whose
        argmax prediction equals the label.

    Raises:
        ValueError: if ``data_loader`` yields no examples (the metrics would
            otherwise divide by zero).
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    total_examples = 0

    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        batch_size = images.size(0)

        with torch.set_grad_enabled(training):
            logits = model(images)
            loss = criterion(logits, labels)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Assumes criterion averages over the batch (CrossEntropyLoss's default
        # reduction); re-weight by batch size so the epoch loss is a per-example mean.
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_examples += batch_size

    if total_examples == 0:
        raise ValueError("data_loader yielded no examples; cannot compute epoch metrics.")

    return total_loss / total_examples, total_correct / total_examples
|
|
|
|
|
|
def export_onnx(model: nn.Module, model_dir: Path, image_size: int, *, dynamic_batch: bool = False) -> None:
    """Export ``model`` to ``model_dir / "model.onnx"`` using ONNX opset 17.

    The graph is traced with a random ``(1, 3, image_size, image_size)`` input.
    Its single input is named ``input`` and its single output ``logits``.

    Args:
        model: Network to export. Switched to eval mode as a side effect.
        model_dir: Directory that receives ``model.onnx``. It must already exist.
        image_size: Height and width of the square tracing input.
        dynamic_batch: If true, mark dimension 0 of ``input`` and ``logits`` as a
            symbolic ``batch`` axis so the exported graph accepts any batch
            size. The default (false) keeps the batch dimension fixed at 1.
    """
    model.eval()
    # Trace on the device the weights live on. A CPU-only dummy input would fail
    # against a model whose parameters are on CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
    dynamic_axes = {"input": {0: "batch"}, "logits": {0: "batch"}} if dynamic_batch else None
    torch.onnx.export(
        model,
        dummy_input,
        model_dir / "model.onnx",
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
    )
|
|
|
|
|
|
def _build_loaders(
    args: argparse.Namespace, train_dataset: Subset, validation_dataset: Subset
) -> tuple[DataLoader, DataLoader]:
    """Wrap both splits in DataLoaders; only the training loader is shuffled."""
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    validation_loader = DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    return train_loader, validation_loader


def _fit(
    model: nn.Module,
    train_loader: DataLoader,
    validation_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epochs: int,
) -> list[dict[str, float]]:
    """Train for ``epochs`` epochs, printing and returning one metrics dict per epoch."""
    metrics = []
    for epoch in range(1, epochs + 1):
        train_loss, train_accuracy = run_epoch(model, train_loader, criterion, optimizer, device)
        validation_loss, validation_accuracy = run_epoch(model, validation_loader, criterion, None, device)
        epoch_metrics = {
            "epoch": epoch,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "validation_loss": validation_loss,
            "validation_accuracy": validation_accuracy,
        }
        metrics.append(epoch_metrics)
        # One JSON line per epoch on stdout.
        print(json.dumps(epoch_metrics, sort_keys=True))
    return metrics


def _save_artifacts(
    model: nn.Module,
    model_dir: Path,
    class_to_idx: dict[str, int],
    image_size: int,
    metrics: list[dict[str, float]],
) -> None:
    """Write model.pt, model.onnx, class_to_idx.json and metrics.json into ``model_dir``."""
    model_dir.mkdir(parents=True, exist_ok=True)
    # Move to CPU explicitly before saving and exporting. nn.Module.cpu() works
    # in place, so the caller's model moves too; previously this happened as a
    # hidden side effect inside the checkpoint dict, and the ONNX export
    # silently depended on it.
    model = model.cpu()
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "class_to_idx": class_to_idx,
            "image_size": image_size,
        },
        model_dir / "model.pt",
    )
    export_onnx(model, model_dir, image_size)
    (model_dir / "class_to_idx.json").write_text(json.dumps(class_to_idx, indent=2), encoding="utf-8")
    (model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")


def main() -> None:
    """Train the classifier with the parsed CLI settings and save all artifacts.

    Seeds ``random`` and ``torch``, trains on CUDA when available (CPU
    otherwise), prints one JSON metrics line per epoch, and writes
    ``model.pt``, ``model.onnx``, ``class_to_idx.json`` and ``metrics.json``
    into ``--model-dir``.
    """
    args = parse_args()
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    train_dataset, validation_dataset, class_to_idx = build_datasets(args)
    train_loader, validation_loader = _build_loaders(args, train_dataset, validation_dataset)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SmallImageClassifier(class_count=len(class_to_idx)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    print(f"Training on {device}. Classes: {sorted(class_to_idx)}")
    metrics = _fit(model, train_loader, validation_loader, criterion, optimizer, device, args.epochs)

    model_dir = Path(args.model_dir)
    _save_artifacts(model, model_dir, class_to_idx, args.image_size, metrics)
    print(f"Saved model artifacts to {model_dir}")
|
|
|
|
|
|
# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|