
Commit 1fc5993

working test with CPUs
Signed-off-by: Kevin <[email protected]>
1 parent: 9be537a

File tree: 3 files changed, +44 -355 lines


tests/kfto/kfto_mnist_sdk_test.go

Lines changed: 10 additions & 11 deletions
@@ -19,7 +19,6 @@ package kfto
 import (
 	"strings"
 	"testing"
-	"time"

 	. "github.com/onsi/gomega"
 	. "github.com/project-codeflare/codeflare-common/support"
@@ -29,33 +28,33 @@ import (

 func TestMnistSDK(t *testing.T) {
 	test := With(t)
-
 	// Create a namespace
 	namespace := test.NewTestNamespace()
 	userName := GetNotebookUserName(test)
 	userToken := GetNotebookUserToken(test)
 	jupyterNotebookConfigMapFileName := "mnist_kfto.ipynb"
-	mnist := readMnistScriptTemplate(test, "resources/kfto_sdk_train.py")
+	mnist := readMnistScriptTemplate(test, "resources/kfto_sdk_mnist.py")

 	// Create role binding with Namespace specific admin cluster role
 	CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin")

 	requiredChangesInNotebook := map[string]string{
-		"${api_url}":        GetOpenShiftApiUrl(test),
-		"${train_function}": "train_func_2",
-		"${password}":       userToken,
-		"${num_gpus}":       "2",
-		"${namespace}":      namespace.Name,
+		"${api_url}":   GetOpenShiftApiUrl(test),
+		"${password}":  userToken,
+		"${num_gpus}":  "0",
+		"${namespace}": namespace.Name,
 	}

 	jupyterNotebook := string(ReadFile(test, "resources/mnist_kfto.ipynb"))
+	requirements := ReadFile(test, "resources/requirements.txt")
 	for oldValue, newValue := range requiredChangesInNotebook {
 		jupyterNotebook = strings.Replace(string(jupyterNotebook), oldValue, newValue, -1)
 	}

 	config := CreateConfigMap(test, namespace.Name, map[string][]byte{
 		jupyterNotebookConfigMapFileName: []byte(jupyterNotebook),
 		"kfto_sdk_mnist.py":              mnist,
+		"requirements.txt":               requirements,
 	})

 	// Create Notebook CR
@@ -68,15 +67,15 @@ func TestMnistSDK(t *testing.T) {
 	}()

 	// Make sure pytorch job is created
-	Eventually(PyTorchJob(test, namespace.Name, "pytorch-ddp")).
+	test.Eventually(PyTorchJob(test, namespace.Name, "pytorch-ddp")).
 		Should(WithTransform(PyTorchJobConditionRunning, Equal(v1.ConditionTrue)))

 	// Make sure that the job eventually succeeds
-	Eventually(PyTorchJob(test, namespace.Name, "pytorch-ddp")).
+	test.Eventually(PyTorchJob(test, namespace.Name, "pytorch-ddp")).
 		Should(WithTransform(PyTorchJobConditionSucceeded, Equal(v1.ConditionTrue)))

 	// TODO: write torch job logs?
-	time.Sleep(60 * time.Second)
+	// time.Sleep(60 * time.Second)
 }

 func readMnistScriptTemplate(test Test, filePath string) []byte {
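
The switch from the package-level Eventually to test.Eventually presumably makes the assertions pick up the Gomega instance (and timeout defaults) carried by the codeflare-common Test object. For readers unfamiliar with the polling pattern, here is a minimal self-contained sketch of Gomega's Eventually combined with WithTransform; fetchJobState is a hypothetical stand-in for the PyTorchJob getter used above:

package kfto_test

import (
	"testing"

	. "github.com/onsi/gomega"
)

// fetchJobState is a hypothetical stand-in for PyTorchJob(test, ns, name);
// Gomega re-invokes it on every polling attempt.
func fetchJobState() string {
	return "Running"
}

func TestEventuallyPattern(t *testing.T) {
	g := NewWithT(t)
	// Eventually keeps calling fetchJobState until the transformed value
	// satisfies the matcher, or the configured timeout elapses.
	g.Eventually(fetchJobState).
		Should(WithTransform(func(s string) bool { return s == "Running" }, BeTrue()))
}

Because the polled function is re-run on each attempt, the transform always sees fresh job state, which is what lets the test wait for Running and then for Succeeded.
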
tests/kfto/resources/kfto_sdk_mnist.py

Lines changed: 5 additions & 278 deletions
@@ -1,70 +1,11 @@
 def train_func():
-    import os
-    import torch
-    import torch.distributed as dist
-    import torch.nn as nn
-    import torch.optim as optim
-    from torchvision import datasets, transforms
-    from torch.utils.data import DataLoader, DistributedSampler
-
-    # Initialize distributed process group
-    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
-    rank = dist.get_rank()
-    world_size = dist.get_world_size()
-    local_rank = int(os.getenv("LOCAL_RANK", 0))
-    torch.cuda.set_device(local_rank)
-
-    # Configuration
-    batch_size = 64
-    epochs = 5
-    learning_rate = 0.01
-
-    # Dataset and DataLoader
-    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
-    train_dataset = datasets.MNIST(root="/tmp/datasets/mnist", train=True, download=True, transform=transform)
-    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
-    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)
-
-    # Model, Loss, and Optimizer
-    model = nn.Sequential(
-        nn.Flatten(),
-        nn.Linear(28 * 28, 128),
-        nn.ReLU(),
-        nn.Linear(128, 10)
-    ).cuda(local_rank)
-
-    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
-    criterion = nn.CrossEntropyLoss().cuda(local_rank)
-    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
-
-    # Training loop
-    for epoch in range(epochs):
-        model.train()
-        epoch_loss = 0
-        for batch_idx, (data, target) in enumerate(train_loader):
-            data, target = data.cuda(local_rank, non_blocking=True), target.cuda(local_rank, non_blocking=True)
-
-            optimizer.zero_grad()
-            output = model(data)
-            loss = criterion(output, target)
-            loss.backward()
-            optimizer.step()
-
-            epoch_loss += loss.item()
-
-        # Log epoch stats
-        print(f"Rank {rank} | Epoch {epoch + 1}/{epochs} | Loss: {epoch_loss / len(train_loader)}")
-
-    # Cleanup
-    dist.destroy_process_group()
-
-def train_func_2():
     import os
     import torch
     import torch.nn.functional as F
     from torch.utils.data import DistributedSampler
     from torchvision import datasets, transforms
     import torch.distributed as dist
+    from pathlib import Path

     # [1] Setup PyTorch DDP. Distributed environment will be set automatically by Training Operator.
     dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
@@ -98,14 +39,15 @@ def forward(self, x):
             return F.log_softmax(x, dim=1)

     # [3] Attach model to the correct GPU device and distributor.
-    device = torch.device(f"cuda:{local_rank}")
+    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
     model = Net().to(device)
     model = Distributor(model)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

     # [4] Setup FashionMNIST dataloader and distribute data across PyTorchJob workers.
+    Path(f"./data{local_rank}").mkdir(exist_ok=True)
     dataset = datasets.FashionMNIST(
-        "./data",
+        f"./data{local_rank}",
         download=True,
         train=True,
         transform=transforms.Compose([transforms.ToTensor()]),
@@ -139,219 +81,4 @@ def forward(self, x):
                 )
             )

-def train_func_3():
-    import os
-
-    import torch
-    import requests
-    from pytorch_lightning import LightningModule, Trainer
-    from pytorch_lightning.callbacks.progress import TQDMProgressBar
-    from torch import nn
-    from torch.nn import functional as F
-    from torch.utils.data import DataLoader, random_split, RandomSampler
-    from torchmetrics import Accuracy
-    from torchvision import transforms
-    from torchvision.datasets import MNIST
-    import gzip
-    import shutil
-    from minio import Minio
-
-
-    PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
-    BATCH_SIZE = 256 if torch.cuda.is_available() else 64
-
-    local_mnist_path = os.path.dirname(os.path.abspath(__file__))
-
-    print("prior to running the trainer")
-    print("MASTER_ADDR: is ", os.getenv("MASTER_ADDR"))
-    print("MASTER_PORT: is ", os.getenv("MASTER_PORT"))
-
-
-    STORAGE_BUCKET_EXISTS = "{{.StorageBucketDefaultEndpointExists}}"
-    print("STORAGE_BUCKET_EXISTS: ",STORAGE_BUCKET_EXISTS)
-    print(f"{'Storage_Bucket_Default_Endpoint : is {{.StorageBucketDefaultEndpoint}}' if '{{.StorageBucketDefaultEndpointExists}}' == 'true' else ''}")
-    print(f"{'Storage_Bucket_Name : is {{.StorageBucketName}}' if '{{.StorageBucketNameExists}}' == 'true' else ''}")
-    print(f"{'Storage_Bucket_Mnist_Directory : is {{.StorageBucketMnistDir}}' if '{{.StorageBucketMnistDirExists}}' == 'true' else ''}")
-
-    class LitMNIST(LightningModule):
-        def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
-            super().__init__()
-
-            # Set our init args as class attributes
-            self.data_dir = data_dir
-            self.hidden_size = hidden_size
-            self.learning_rate = learning_rate
-
-            # Hardcode some dataset specific attributes
-            self.num_classes = 10
-            self.dims = (1, 28, 28)
-            channels, width, height = self.dims
-            self.transform = transforms.Compose(
-                [
-                    transforms.ToTensor(),
-                    transforms.Normalize((0.1307,), (0.3081,)),
-                ]
-            )
-
-            # Define PyTorch model
-            self.model = nn.Sequential(
-                nn.Flatten(),
-                nn.Linear(channels * width * height, hidden_size),
-                nn.ReLU(),
-                nn.Dropout(0.1),
-                nn.Linear(hidden_size, hidden_size),
-                nn.ReLU(),
-                nn.Dropout(0.1),
-                nn.Linear(hidden_size, self.num_classes),
-            )
-
-            self.val_accuracy = Accuracy()
-            self.test_accuracy = Accuracy()
-
-        def forward(self, x):
-            x = self.model(x)
-            return F.log_softmax(x, dim=1)
-
-        def training_step(self, batch, batch_idx):
-            x, y = batch
-            logits = self(x)
-            loss = F.nll_loss(logits, y)
-            return loss
-
-        def validation_step(self, batch, batch_idx):
-            x, y = batch
-            logits = self(x)
-            loss = F.nll_loss(logits, y)
-            preds = torch.argmax(logits, dim=1)
-            self.val_accuracy.update(preds, y)
-
-            # Calling self.log will surface up scalars for you in TensorBoard
-            self.log("val_loss", loss, prog_bar=True)
-            self.log("val_acc", self.val_accuracy, prog_bar=True)
-
-        def test_step(self, batch, batch_idx):
-            x, y = batch
-            logits = self(x)
-            loss = F.nll_loss(logits, y)
-            preds = torch.argmax(logits, dim=1)
-            self.test_accuracy.update(preds, y)
-
-            # Calling self.log will surface up scalars for you in TensorBoard
-            self.log("test_loss", loss, prog_bar=True)
-            self.log("test_acc", self.test_accuracy, prog_bar=True)
-
-        def configure_optimizers(self):
-            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
-            return optimizer
-
-        ####################
-        # DATA RELATED HOOKS
-        ####################
-
-        def prepare_data(self):
-            # download
-            print("Downloading MNIST dataset...")
-
-            if "{{.StorageBucketDefaultEndpointExists}}" == "true" and "{{.StorageBucketDefaultEndpoint}}" != "":
-                print("Using storage bucket to download datasets...")
-                dataset_dir = os.path.join(self.data_dir, "MNIST/raw")
-                endpoint = "{{.StorageBucketDefaultEndpoint}}"
-                access_key = "{{.StorageBucketAccessKeyId}}"
-                secret_key = "{{.StorageBucketSecretKey}}"
-                bucket_name = "{{.StorageBucketName}}"
-
-                # remove prefix if specified in storage bucket endpoint url
-                secure = True
-                if endpoint.startswith("https://"):
-                    endpoint = endpoint[len("https://") :]
-                elif endpoint.startswith("http://"):
-                    endpoint = endpoint[len("http://") :]
-                    secure = False
-
-                client = Minio(
-                    endpoint,
-                    access_key=access_key,
-                    secret_key=secret_key,
-                    cert_check=False,
-                    secure=secure
-                )
-
-                if not os.path.exists(dataset_dir):
-                    os.makedirs(dataset_dir)
-                else:
-                    print(f"Directory '{dataset_dir}' already exists")
-
-                # To download datasets from storage bucket's specific directory, use prefix to provide directory name
-                prefix="{{.StorageBucketMnistDir}}"
-                # download all files from prefix folder of storage bucket recursively
-                for item in client.list_objects(
-                    bucket_name, prefix=prefix, recursive=True
-                ):
-                    file_name=item.object_name[len(prefix)+1:]
-                    dataset_file_path = os.path.join(dataset_dir, file_name)
-                    print(dataset_file_path)
-                    if not os.path.exists(dataset_file_path):
-                        client.fget_object(
-                            bucket_name, item.object_name, dataset_file_path
-                        )
-                    else:
-                        print(f"File-path '{dataset_file_path}' already exists")
-                    # Unzip files
-                    with gzip.open(dataset_file_path, "rb") as f_in:
-                        with open(dataset_file_path.split(".")[:-1][0], "wb") as f_out:
-                            shutil.copyfileobj(f_in, f_out)
-                    # delete zip file
-                    os.remove(dataset_file_path)
-                download_datasets = False
-
-            else:
-                print("Using default MNIST mirror reference to download datasets...")
-                download_datasets = True
-
-            MNIST(self.data_dir, train=True, download=download_datasets)
-            MNIST(self.data_dir, train=False, download=download_datasets)
-
-        def setup(self, stage=None):
-
-            # Assign train/val datasets for use in dataloaders
-            if stage == "fit" or stage is None:
-                mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
-                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
-
-            # Assign test dataset for use in dataloader(s)
-            if stage == "test" or stage is None:
-                self.mnist_test = MNIST(
-                    self.data_dir, train=False, transform=self.transform
-                )
-
-        def train_dataloader(self):
-            return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000))
-
-        def val_dataloader(self):
-            return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
-
-        def test_dataloader(self):
-            return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
-
-
-    # Init DataLoader from MNIST Dataset
-
-    model = LitMNIST(data_dir=local_mnist_path)
-
-    print("GROUP: ", int(os.environ.get("GROUP_WORLD_SIZE", 1)))
-    print("LOCAL: ", int(os.environ.get("LOCAL_WORLD_SIZE", 1)))
-
-    # Initialize a trainer
-    trainer = Trainer(
-        accelerator="has to be specified",
-        # devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
-        max_epochs=3,
-        callbacks=[TQDMProgressBar(refresh_rate=20)],
-        num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
-        devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
-        replace_sampler_ddp=False,
-        strategy="ddp",
-    )
-
-    # Train the model ⚡
-    trainer.fit(model)
+    dist.destroy_process_group()
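
The notebook resource (resources/mnist_kfto.ipynb) that actually submits this function is not part of the diff. For orientation only, a CPU-only submission of train_func through the Kubeflow Training SDK might look like the sketch below; it assumes the kubeflow-training TrainingClient API, the job name pytorch-ddp matches what the Go test polls for, and the worker count and resource requests are illustrative:

# Sketch only; assumes the kubeflow-training package and that train_func
# is importable from the kfto_sdk_mnist.py script shipped in the ConfigMap.
from kubeflow.training import TrainingClient

from kfto_sdk_mnist import train_func

client = TrainingClient()

# Submit train_func as a PyTorchJob named "pytorch-ddp", the name the Go
# test waits on. ${num_gpus} = "0" in the test maps to CPU-only workers.
client.create_job(
    name="pytorch-ddp",
    train_func=train_func,
    num_workers=2,  # illustrative worker count
    resources_per_worker={"cpu": "2", "memory": "4Gi"},  # no GPUs requested
)

# Poll until the job reports a terminal condition, then fetch its logs.
client.wait_for_job_conditions(name="pytorch-ddp")
print(client.get_job_logs(name="pytorch-ddp"))

The gloo fallback in train_func is what lets the same script run on these CPU-only workers, since the nccl backend requires CUDA devices.
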
