
Commit 65cc09e

Add test coverage for multi-node multi-GPU MNIST training

1 parent 4102ee1 · commit 65cc09e

File tree

7 files changed: +773 −6 lines changed

tests/common/support/environment.go

Lines changed: 12 additions & 6 deletions

@@ -51,12 +51,13 @@ const (
 	pipTrustedHost = "PIP_TRUSTED_HOST"
 
 	// Storage bucket credentials
-	storageDefaultEndpoint = "AWS_DEFAULT_ENDPOINT"
-	storageDefaultRegion   = "AWS_DEFAULT_REGION"
-	storageAccessKeyId     = "AWS_ACCESS_KEY_ID"
-	storageSecretKey       = "AWS_SECRET_ACCESS_KEY"
-	storageBucketName      = "AWS_STORAGE_BUCKET"
-	storageBucketMnistDir  = "AWS_STORAGE_BUCKET_MNIST_DIR"
+	storageDefaultEndpoint       = "AWS_DEFAULT_ENDPOINT"
+	storageDefaultRegion         = "AWS_DEFAULT_REGION"
+	storageAccessKeyId           = "AWS_ACCESS_KEY_ID"
+	storageSecretKey             = "AWS_SECRET_ACCESS_KEY"
+	storageBucketName            = "AWS_STORAGE_BUCKET"
+	storageBucketMnistDir        = "AWS_STORAGE_BUCKET_MNIST_DIR"
+	storageBucketFashionMnistDir = "AWS_STORAGE_BUCKET_FASHION_MNIST_DIR"
 
 	// Name of existing namespace to be used for test
 	testNamespaceNameEnvVar = "TEST_NAMESPACE_NAME"

@@ -179,6 +180,11 @@ func GetStorageBucketMnistDir() (string, bool) {
 	return storage_bucket_mnist_dir, exists
 }
 
+func GetStorageBucketFashionMnistDir() (string, bool) {
+	storage_bucket_fashion_mnist_dir, exists := os.LookupEnv(storageBucketFashionMnistDir)
+	return storage_bucket_fashion_mnist_dir, exists
+}
+
 func GetPipIndexURL() string {
 	return lookupEnvOrDefault(pipIndexURL, "https://pypi.python.org/simple")
 }
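The new accessor follows the existing GetStorageBucketMnistDir pattern: it returns the directory name plus a flag indicating whether the variable is exported, so callers can fall back or skip gracefully. A minimal sketch of how a test might consume it (the requireFashionMnistDir helper is an illustration, not part of this commit):

package support

import "testing"

// requireFashionMnistDir is a hypothetical helper (not part of this commit):
// it skips bucket-backed tests when the Fashion-MNIST directory variable is
// not exported, mirroring how the MNIST accessor is typically consumed.
func requireFashionMnistDir(t *testing.T) string {
	dir, found := GetStorageBucketFashionMnistDir()
	if !found {
		t.Skip("AWS_STORAGE_BUCKET_FASHION_MNIST_DIR is not set")
	}
	return dir
}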

tests/common/support/jobset.go

Lines changed: 4 additions & 0 deletions

@@ -68,6 +68,10 @@ func JobSetConditionFailed(jobset *jobsetv1alpha2.JobSet) metav1.ConditionStatus
 	return JobSetCondition(jobset, jobsetv1alpha2.JobSetFailed)
 }
 
+func JobSetConditionCompleted(jobset *jobsetv1alpha2.JobSet) metav1.ConditionStatus {
+	return JobSetCondition(jobset, jobsetv1alpha2.JobSetCompleted)
+}
+
 func JobSetFailureMessage(jobset *jobsetv1alpha2.JobSet) string {
 	if jobset == nil {
 		return ""
Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+import os, gzip, shutil
+from minio import Minio
+from torchvision import datasets
+from torchvision.transforms import Compose, ToTensor
+
+def main(dataset_path):
+    # Download and load the Fashion-MNIST dataset
+    if all(var in os.environ for var in ["AWS_DEFAULT_ENDPOINT","AWS_ACCESS_KEY_ID","AWS_SECRET_ACCESS_KEY","AWS_STORAGE_BUCKET","AWS_STORAGE_BUCKET_FASHION_MNIST_DIR"]):
+        print("Using provided storage bucket to download Fashion-MNIST datasets...")
+        dataset_dir = os.path.join(dataset_path, "FashionMNIST/raw")
+        endpoint = os.environ.get("AWS_DEFAULT_ENDPOINT")
+        access_key = os.environ.get("AWS_ACCESS_KEY_ID")
+        secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
+        bucket_name = os.environ.get("AWS_STORAGE_BUCKET")
+        print(f"Storage bucket endpoint: {endpoint}")
+        print(f"Storage bucket name: {bucket_name}\n")
+
+        # Remove the scheme prefix, if present, from the storage bucket endpoint URL
+        secure = True
+        if endpoint.startswith("https://"):
+            endpoint = endpoint[len("https://"):]
+        elif endpoint.startswith("http://"):
+            endpoint = endpoint[len("http://"):]
+            secure = False
+
+        client = Minio(
+            endpoint,
+            access_key=access_key,
+            secret_key=secret_key,
+            cert_check=False,
+            secure=secure
+        )
+        if not os.path.exists(dataset_dir):
+            os.makedirs(dataset_dir)
+        else:
+            print(f"Directory '{dataset_dir}' already exists")
+
+        # To download datasets from a specific directory of the storage bucket, pass the directory name as the prefix
+        prefix = os.environ.get("AWS_STORAGE_BUCKET_FASHION_MNIST_DIR")
+        print(f"Storage bucket Fashion-MNIST directory prefix: {prefix}\n")
+
+        # Download all files from the prefix folder of the storage bucket recursively
+        for item in client.list_objects(
+            bucket_name, prefix=prefix, recursive=True
+        ):
+            file_name = item.object_name[len(prefix)+1:]
+            dataset_file_path = os.path.join(dataset_dir, file_name)
+            print(f"Downloading dataset file {file_name} to {dataset_file_path}..")
+            if not os.path.exists(dataset_file_path):
+                client.fget_object(
+                    bucket_name, item.object_name, dataset_file_path
+                )
+                # Unzip files
+                ## Sample zip file path: ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
+                with gzip.open(dataset_file_path, "rb") as f_in:
+                    filename = file_name.split(".")[0]  # -> t10k-images-idx3-ubyte
+                    file_path = "/".join(dataset_file_path.split("/")[:-1])  # -> ../data/FashionMNIST/raw
+                    full_file_path = os.path.join(file_path, filename)  # -> ../data/FashionMNIST/raw/t10k-images-idx3-ubyte
+                    print(f"Extracting {dataset_file_path} to {file_path}..")
+
+                    with open(full_file_path, "wb") as f_out:
+                        shutil.copyfileobj(f_in, f_out)
+                    print(f"Dataset file downloaded : {full_file_path}\n")
+                # Delete the zip file
+                os.remove(dataset_file_path)
+            else:
+                print(f"File-path '{dataset_file_path}' already exists")
+        download_datasets = False
+    else:
+        print("Using default Fashion-MNIST mirror references to download datasets...")
+        print("Skipped usage of S3 storage bucket, because required environment variables aren't provided!\nRequired environment variables: AWS_DEFAULT_ENDPOINT, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_STORAGE_BUCKET, AWS_STORAGE_BUCKET_FASHION_MNIST_DIR")
+        download_datasets = True
+
+    datasets.FashionMNIST(
+        dataset_path,
+        train=True,
+        download=download_datasets,
+        transform=Compose([ToTensor()])
+    )
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description="Fashion-MNIST dataset download")
+    parser.add_argument('--dataset_path', type=str, default="./data", help='Path to Fashion-MNIST datasets (default: ./data)')
+
+    args = parser.parse_args()
+
+    main(
+        dataset_path=args.dataset_path,
+    )
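This script prefers an S3-compatible bucket when all five AWS_* variables are present and otherwise lets torchvision pull the dataset from its public mirrors. In the Go test harness, those variables have to be forwarded from the test runner's environment into the container that runs the script; a hypothetical sketch of that wiring (storageBucketEnv is an illustration, not part of this commit):

package trainer

import (
	"os"

	corev1 "k8s.io/api/core/v1"
)

// storageBucketEnv is a hypothetical helper (not part of this commit): it
// forwards the storage bucket credentials into the container that runs the
// dataset download script. Missing variables are simply omitted, so the
// script falls back to the public Fashion-MNIST mirrors.
func storageBucketEnv() []corev1.EnvVar {
	names := []string{
		"AWS_DEFAULT_ENDPOINT",
		"AWS_ACCESS_KEY_ID",
		"AWS_SECRET_ACCESS_KEY",
		"AWS_STORAGE_BUCKET",
		"AWS_STORAGE_BUCKET_FASHION_MNIST_DIR",
	}
	var envVars []corev1.EnvVar
	for _, name := range names {
		if value, ok := os.LookupEnv(name); ok {
			envVars = append(envVars, corev1.EnvVar{Name: name, Value: value})
		}
	}
	return envVars
}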
Lines changed: 139 additions & 0 deletions

@@ -0,0 +1,139 @@
+def train_pytorch():
+    import os
+    import logging
+
+    import torch
+    from torch import nn
+    import torch.nn.functional as F
+
+    from torchvision import datasets, transforms
+    import torch.distributed as dist
+    from torch.utils.data import DataLoader, DistributedSampler
+
+    # Configure logger (similar to KFTO mnist.py)
+    log_formatter = logging.Formatter(
+        "%(asctime)s %(levelname)-8s %(message)s", "%Y-%m-%dT%H:%M:%SZ"
+    )
+    logger = logging.getLogger(__file__)
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(log_formatter)
+    logger.addHandler(console_handler)
+    logger.setLevel(logging.INFO)
+
+    # [1] Configure CPU/GPU device and distributed backend.
+    # Kubeflow Trainer will automatically configure the distributed environment.
+    device, backend = ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo")
+    dist.init_process_group(backend=backend)
+
+    local_rank = int(os.getenv("LOCAL_RANK", 0))
+    logger.info(
+        "Distributed Training with WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}.".format(
+            dist.get_world_size(),
+            dist.get_rank(),
+            local_rank,
+        )
+    )
+
+    # [2] Define the PyTorch CNN model to be trained.
+    class Net(nn.Module):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.conv1 = nn.Conv2d(1, 32, 3, 1)
+            self.conv2 = nn.Conv2d(32, 64, 3, 1)
+            self.fc1 = nn.Linear(9216, 128)
+            self.fc2 = nn.Linear(128, 10)
+
+        def forward(self, x):
+            x = F.relu(self.conv1(x))
+            x = F.relu(self.conv2(x))
+            x = F.max_pool2d(x, 2)
+            x = x.view(-1, 9216)
+            x = F.relu(self.fc1(x))
+            x = self.fc2(x)
+            return F.log_softmax(x, dim=1)
+
+    # [3] Attach the model to the correct device.
+    device = torch.device(f"{device}:{local_rank}")
+    model = nn.parallel.DistributedDataParallel(Net().to(device))
+    model.train()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+
+    # [4] Get the Fashion-MNIST dataset.
+    # The dataset should be pre-downloaded to avoid network dependencies.
+    dataset_path = os.getenv("DATASET_PATH", "./data")
+
+    # Load the dataset (download=False assumes the dataset is already present)
+    dataset = datasets.FashionMNIST(
+        dataset_path,
+        train=True,
+        download=False,
+        transform=transforms.Compose([transforms.ToTensor()]),
+    )
+    # Batch size is configurable via env var (smaller = more iterations = longer training)
+    batch_size = int(os.getenv("BATCH_SIZE", "64"))
+    train_loader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=DistributedSampler(dataset),
+    )
+
+    # [5] Define the training loop.
+    num_epochs = int(os.getenv("NUM_EPOCHS", "1"))
+    global_rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    for epoch in range(num_epochs):
+        # Log epoch start from ALL ranks
+        num_batches = len(train_loader)
+        device_type = "GPU" if torch.cuda.is_available() else "CPU"
+        logger.info(f"[{device_type}{global_rank}] Epoch {epoch} | Batchsize: {batch_size} | Steps: {num_batches} | World Size: {world_size}")
+
+        # Set the epoch for the DistributedSampler to ensure proper shuffling
+        if isinstance(train_loader.sampler, DistributedSampler):
+            train_loader.sampler.set_epoch(epoch)
+
+        epoch_loss = 0.0
+        num_batches_processed = 0
+
+        for batch_idx, (inputs, labels) in enumerate(train_loader):
+            # Attach tensors to the device.
+            inputs, labels = inputs.to(device), labels.to(device)
+
+            # Forward pass
+            outputs = model(inputs)
+            loss = F.nll_loss(outputs, labels)
+
+            # Backward pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # Track loss for the epoch summary
+            epoch_loss += loss.item()
+            num_batches_processed += 1
+
+            # Log detailed training progress from rank 0 only (to avoid log spam)
+            if batch_idx % 10 == 0 and global_rank == 0:
+                logger.info(
+                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                        epoch,
+                        batch_idx * len(inputs) * world_size,  # Adjust for distributed training
+                        len(train_loader.dataset),
+                        100.0 * batch_idx / num_batches,
+                        loss.item(),
+                    )
+                )
+
+        # End-of-epoch summary from ALL ranks
+        avg_loss = epoch_loss / num_batches_processed
+        logger.info(f"[{device_type}{global_rank}] Epoch {epoch} completed | Avg Loss: {avg_loss:.6f} | Batches: {num_batches_processed}")
+
+    # Wait for the training to complete and destroy the PyTorch distributed process group.
+    dist.barrier()
+    # All ranks report completion
+    logger.info(f"[{device_type}{global_rank}] Training is finished")
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    train_pytorch()
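The training script is driven entirely by environment variables: Kubeflow Trainer injects RANK, WORLD_SIZE, and LOCAL_RANK for init_process_group, while DATASET_PATH, BATCH_SIZE, and NUM_EPOCHS are test-tunable knobs. A hypothetical sketch of how a Go test could set those knobs on the trainer container (trainingEnv is an illustration, not part of this commit):

package trainer

import (
	"strconv"

	corev1 "k8s.io/api/core/v1"
)

// trainingEnv is a hypothetical helper (not part of this commit): it builds
// the container environment variables the training script reads, letting a
// test stretch or shrink the run on multi-node multi-GPU clusters.
func trainingEnv(datasetPath string, batchSize, numEpochs int) []corev1.EnvVar {
	return []corev1.EnvVar{
		{Name: "DATASET_PATH", Value: datasetPath},
		{Name: "BATCH_SIZE", Value: strconv.Itoa(batchSize)},
		{Name: "NUM_EPOCHS", Value: strconv.Itoa(numEpochs)},
	}
}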
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+# Install only packages which are not present in the base image
+# Use the "# no-deps" marker for packages that should be installed without dependencies
+
+minio==7.2.13
+torchvision==0.23.0 # no-deps
+pillow==11.0.0 # no-deps
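The "# no-deps" marker is a convention for the test tooling rather than something pip understands natively. One plausible way a harness could interpret it, splitting packages into a normal install set and a pip --no-deps set (splitRequirements is an illustration under that assumption, not part of this commit):

package trainer

import "strings"

// splitRequirements is a hypothetical helper (not part of this commit): it
// partitions requirements lines into packages installed normally and packages
// installed with pip's --no-deps flag, based on the "# no-deps" marker.
func splitRequirements(lines []string) (withDeps, noDeps []string) {
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue // skip blanks and pure comments
		}
		if spec, found := strings.CutSuffix(line, "# no-deps"); found {
			noDeps = append(noDeps, strings.TrimSpace(spec))
		} else {
			withDeps = append(withDeps, line)
		}
	}
	return withDeps, noDeps
}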

tests/trainer/support.go

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+/*
+Copyright 2025.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package trainer
+
+import (
+	"embed"
+
+	"github.com/onsi/gomega"
+
+	"github.com/opendatahub-io/distributed-workloads/tests/common/support"
+)
+
+//go:embed resources/*
+var files embed.FS
+
+func readFile(t support.Test, fileName string) []byte {
+	t.T().Helper()
+	file, err := files.ReadFile(fileName)
+	t.Expect(err).NotTo(gomega.HaveOccurred())
+	return file
+}
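readFile resolves names against the embedded resources/ tree and fails the test immediately if the file is missing, so the Python scripts above travel with the compiled test binary. A hypothetical usage sketch (the resources/train.py file name is an assumption, not part of this commit):

package trainer

import "github.com/opendatahub-io/distributed-workloads/tests/common/support"

// loadTrainingScript is a hypothetical helper (not part of this commit); the
// file name under resources/ is an assumption. The embedded bytes can then be
// mounted into the training job, e.g. via a ConfigMap.
func loadTrainingScript(test support.Test) []byte {
	return readFile(test, "resources/train.py")
}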
