Skip to content

Commit 6690be7

Browse files
committed
Add SageMaker example
1 parent 0b74797 commit 6690be7

File tree

4 files changed

+429
-0
lines changed

4 files changed

+429
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Training image for the SageMaker example: CUDA 12.6 + cuDNN runtime on Ubuntu 22.04.
FROM --platform=linux/amd64 nvidia/cuda:12.6.2-cudnn-runtime-ubuntu22.04

# Install the distribution's default Python 3 together with its matching pip.
# python3-pip is tied to the default python3 interpreter, and the `train`
# script's shebang is `python3`, so using that single interpreter keeps the
# packages installed below importable by the script.
RUN apt-get update && apt-get install -y --no-install-recommends \
        python3 \
        python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Expose the same interpreter as `python`; -f keeps this step from failing
# if the link already exists in the base image.
RUN ln -sf /usr/bin/python3 /usr/bin/python

# Install Python dependencies (copied first so this layer is cached until
# requirements.txt changes)
COPY requirements.txt /opt/ml/code/requirements.txt
RUN pip install --no-cache-dir -r /opt/ml/code/requirements.txt

# Put the code directory on PATH so the `train` executable can be invoked by
# name (SageMaker starts training containers with the argument `train`)
ENV PATH="/opt/ml/code:${PATH}"

# Copy the training code into the container at /opt/ml/code
COPY train /opt/ml/code/train

# Set the working directory to /opt/ml/code
WORKDIR /opt/ml/code

# Make the training script executable without leaving it world-writable
RUN chmod 755 /opt/ml/code/train

# Empty the entrypoint so the container's command is executed as given
ENTRYPOINT []
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
torch==2.5.1
2+
torchvision==0.20.1
3+
scikit-learn==1.5.2

aws-sagemaker-example/src/train

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#!/usr/bin/env python3
2+
import argparse
3+
import os
4+
import torch
5+
import torch.nn as nn
6+
import torch.optim as optim
7+
from torch.utils.data import DataLoader
8+
from torchvision import datasets, transforms, models
9+
import numpy as np
10+
from sklearn.metrics import roc_auc_score
11+
import json
12+
import logging
13+
14+
# Configure the root logger once at import time: INFO and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-wide handle on the root logger.
logger = logging.getLogger()
16+
17+
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    The encoder applies four 4x4, stride-2, padding-1 convolutions
    (3 -> 64 -> 128 -> 256 -> 1024 channels), each followed by ReLU, so every
    stage halves an even spatial size. The decoder mirrors it with transposed
    convolutions (1024 -> 256 -> 128 -> 64 -> 3) that double the spatial size;
    its final activation is a sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()
        channels = (3, 64, 128, 256, 1024)

        # Encoder: Conv2d + ReLU per stage, built in channel order so the
        # Sequential indices (0, 2, 4, 6 hold the convolutions) stay stable.
        encoder_layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            encoder_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: the mirrored channel sequence; ReLU after every stage
        # except the last, which uses Sigmoid instead.
        mirrored = channels[::-1]
        final_stage = len(mirrored) - 2
        decoder_layers = []
        for stage, (c_in, c_out) in enumerate(zip(mirrored, mirrored[1:])):
            decoder_layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            decoder_layers.append(
                nn.Sigmoid() if stage == final_stage else nn.ReLU()
            )
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        return self.decoder(self.encoder(x))
47+
48+
def main():
    """Train the autoencoder, save its weights, then score the test set.

    Reads hyperparameters from ``/opt/ml/input/config/hyperparameters.json``
    (keys ``batch-size``, ``epochs``, ``learning-rate``), loads image folders
    from ``/opt/ml/input/data/training/{train,test}``, trains with MSE
    reconstruction loss and Adam, writes the state dict to
    ``/opt/ml/model/model.pth`` and finally logs the ROC AUC obtained when the
    per-image reconstruction error is used as the anomaly score.

    Raises:
        KeyError: if one of the three hyperparameter keys is missing.
        ValueError: if a hyperparameter value cannot be parsed as int/float.
    """
    with open('/opt/ml/input/config/hyperparameters.json') as json_file:
        hyperparameters = json.load(json_file)
    logger.info(hyperparameters)

    # Parse each hyperparameter exactly once, up front, so a missing or
    # malformed value fails before any data is loaded. The explicit casts
    # allow the JSON values to be strings.
    batch_size = int(hyperparameters["batch-size"])
    epochs = int(hyperparameters["epochs"])
    learning_rate = float(hyperparameters["learning-rate"])

    data_dir = "/opt/ml/input/data/training"
    model_dir = '/opt/ml/model'

    # Set device: prefer CUDA, then Apple MPS, otherwise CPU
    logger.info(f'Cuda available: {torch.cuda.is_available()}')
    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

    # Define data directories
    train_dir = os.path.join(data_dir, 'train')
    test_dir = os.path.join(data_dir, 'test')

    # Define transforms: fixed 128x128 input, converted to a tensor
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    # Create datasets and loaders
    train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
    test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    logger.info(f'Training dataset size: {len(train_dataset)}, Testing dataset size: {len(test_dataset)}')
    logger.info(f'Batch size: {hyperparameters["batch-size"]}, Epochs: {hyperparameters["epochs"]}, Learning rate: {hyperparameters["learning-rate"]}')

    # Initialize model, criterion, optimizer
    model = Autoencoder().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop: the target of the reconstruction loss is the input itself
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, data)
            loss.backward()
            optimizer.step()
            # Weight the batch mean by batch size so the epoch figure is a
            # true per-sample average even when the last batch is smaller.
            running_loss += loss.item() * data.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        logger.info(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.6f}')

    # Save the trained model before evaluating, so that an error raised
    # during evaluation cannot discard the trained weights.
    model_path = os.path.join(model_dir, 'model.pth')
    torch.save(model.state_dict(), model_path)
    logger.info(f'Model saved to {model_path}')

    # Compute reconstruction errors and labels for test dataset
    errors, labels = _compute_reconstruction_errors(model, test_loader, device)

    logger.info(f'Sample reconstruction errors (first 10): {errors[:10]}')

    # Map labels: 'good' class (1) to 0, 'bad' class (0) to 1.
    # NOTE(review): this assumes the test folder holds exactly two class
    # directories and that the anomalous one received index 0 from
    # ImageFolder -- confirm against the dataset layout.
    anomaly_labels = 1 - np.array(labels)
    anomaly_score = np.array(errors)

    # Compute ROC AUC; it is undefined unless both classes are present, so
    # report that explicitly instead of failing the whole run.
    if len(set(anomaly_labels.tolist())) < 2:
        logger.warning('ROC AUC skipped: the test set does not contain both classes')
    else:
        auc = roc_auc_score(anomaly_labels, anomaly_score)
        logger.info(f'ROC AUC: {auc:.4f}')


def _compute_reconstruction_errors(model, loader, device):
    """Return per-sample reconstruction MSE and labels for every item in ``loader``.

    Args:
        model: module whose output has the same shape as its input.
        loader: iterable of ``(data, label)`` batches; ``data`` is indexed as
            a 4-D tensor (the mean is taken over dims 1, 2 and 3).
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(errors, labels)`` as two parallel Python lists of NumPy scalars,
        in loader order.
    """
    model.eval()
    errors = []
    labels = []
    with torch.no_grad():
        for data, label in loader:
            data = data.to(device)
            outputs = model(data)
            # Average the squared error over every non-batch dimension to get
            # one score per sample.
            loss = torch.mean((outputs - data) ** 2, dim=[1, 2, 3])
            errors.extend(loss.cpu().numpy())
            labels.extend(label.cpu().numpy())
    return errors, labels
132+
133+
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)