
Commit 9be537a

add mnist function to py training file
Signed-off-by: Kevin <[email protected]>
1 parent 0042145 commit 9be537a

File tree: 3 files changed (+249, -5 lines)

tests/kfto/kfto_mnist_sdk_test.go

Lines changed: 30 additions & 2 deletions
@@ -80,10 +80,38 @@ func TestMnistSDK(t *testing.T) {
 }
 
 func readMnistScriptTemplate(test Test, filePath string) []byte {
+	// Read the mnist.py from resources and perform replacements for custom values using go template
+	storage_bucket_endpoint, storage_bucket_endpoint_exists := GetStorageBucketDefaultEndpoint()
+	storage_bucket_access_key_id, storage_bucket_access_key_id_exists := GetStorageBucketAccessKeyId()
+	storage_bucket_secret_key, storage_bucket_secret_key_exists := GetStorageBucketSecretKey()
+	storage_bucket_name, storage_bucket_name_exists := GetStorageBucketName()
+	storage_bucket_mnist_dir, storage_bucket_mnist_dir_exists := GetStorageBucketMnistDir()
+
+	props := struct {
+		StorageBucketDefaultEndpoint       string
+		StorageBucketDefaultEndpointExists bool
+		StorageBucketAccessKeyId           string
+		StorageBucketAccessKeyIdExists     bool
+		StorageBucketSecretKey             string
+		StorageBucketSecretKeyExists       bool
+		StorageBucketName                  string
+		StorageBucketNameExists            bool
+		StorageBucketMnistDir              string
+		StorageBucketMnistDirExists        bool
+	}{
+		StorageBucketDefaultEndpoint:       storage_bucket_endpoint,
+		StorageBucketDefaultEndpointExists: storage_bucket_endpoint_exists,
+		StorageBucketAccessKeyId:           storage_bucket_access_key_id,
+		StorageBucketAccessKeyIdExists:     storage_bucket_access_key_id_exists,
+		StorageBucketSecretKey:             storage_bucket_secret_key,
+		StorageBucketSecretKeyExists:       storage_bucket_secret_key_exists,
+		StorageBucketName:                  storage_bucket_name,
+		StorageBucketNameExists:            storage_bucket_name_exists,
+		StorageBucketMnistDir:              storage_bucket_mnist_dir,
+		StorageBucketMnistDirExists:        storage_bucket_mnist_dir_exists,
+	}
 	template, err := files.ReadFile(filePath)
 	test.Expect(err).NotTo(HaveOccurred())
 
-	props := struct{}{}
-
 	return ParseTemplate(test, template, props)
 }
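The `props` fields correspond one-to-one to the `{{.Field}}` placeholders that the new `kfto_sdk_mnist.py` code below reads. `ParseTemplate` renders the script with Go-template semantics, so the boolean fields arrive in Python as the strings "true"/"false". A rough Python equivalent of that substitution, handy for rendering the script by hand, is sketched here; the `render_script` helper and the sample values are hypothetical, not part of this commit:

import re

def render_script(template_text: str, props: dict) -> str:
    """Hypothetical stand-in for the Go test's ParseTemplate: substitute each
    Go-template placeholder such as {{.StorageBucketName}} with its value."""
    return re.sub(
        r"\{\{\.(\w+)\}\}",
        lambda m: str(props.get(m.group(1), m.group(0))),  # keep unknown placeholders as-is
        template_text,
    )

# Sample values (hypothetical); Go's text/template renders bool fields as the
# strings "true"/"false", which the Python script compares against.
props = {
    "StorageBucketDefaultEndpoint": "https://minio.example.com",
    "StorageBucketDefaultEndpointExists": "true",
    "StorageBucketName": "mnist-datasets",
    "StorageBucketNameExists": "true",
}

with open("tests/kfto/resources/kfto_sdk_mnist.py") as f:
    rendered = render_script(f.read(), props)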

tests/kfto/resources/kfto_sdk_mnist.py

Lines changed: 217 additions & 0 deletions
@@ -138,3 +138,220 @@ def forward(self, x):
                     loss.item(),
                 )
             )
+
+def train_func_3():
+    import os
+
+    import torch
+    import requests
+    from pytorch_lightning import LightningModule, Trainer
+    from pytorch_lightning.callbacks.progress import TQDMProgressBar
+    from torch import nn
+    from torch.nn import functional as F
+    from torch.utils.data import DataLoader, random_split, RandomSampler
+    from torchmetrics import Accuracy
+    from torchvision import transforms
+    from torchvision.datasets import MNIST
+    import gzip
+    import shutil
+    from minio import Minio
+
+    PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
+    BATCH_SIZE = 256 if torch.cuda.is_available() else 64
+
+    local_mnist_path = os.path.dirname(os.path.abspath(__file__))
+
+    print("prior to running the trainer")
+    print("MASTER_ADDR: is ", os.getenv("MASTER_ADDR"))
+    print("MASTER_PORT: is ", os.getenv("MASTER_PORT"))
+
+    STORAGE_BUCKET_EXISTS = "{{.StorageBucketDefaultEndpointExists}}"
+    print("STORAGE_BUCKET_EXISTS: ", STORAGE_BUCKET_EXISTS)
+    print(f"{'Storage_Bucket_Default_Endpoint : is {{.StorageBucketDefaultEndpoint}}' if '{{.StorageBucketDefaultEndpointExists}}' == 'true' else ''}")
+    print(f"{'Storage_Bucket_Name : is {{.StorageBucketName}}' if '{{.StorageBucketNameExists}}' == 'true' else ''}")
+    print(f"{'Storage_Bucket_Mnist_Directory : is {{.StorageBucketMnistDir}}' if '{{.StorageBucketMnistDirExists}}' == 'true' else ''}")
+
+    class LitMNIST(LightningModule):
+        def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
+            super().__init__()
+
+            # Set our init args as class attributes
+            self.data_dir = data_dir
+            self.hidden_size = hidden_size
+            self.learning_rate = learning_rate
+
+            # Hardcode some dataset specific attributes
+            self.num_classes = 10
+            self.dims = (1, 28, 28)
+            channels, width, height = self.dims
+            self.transform = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize((0.1307,), (0.3081,)),
+                ]
+            )
+
+            # Define PyTorch model
+            self.model = nn.Sequential(
+                nn.Flatten(),
+                nn.Linear(channels * width * height, hidden_size),
+                nn.ReLU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_size, hidden_size),
+                nn.ReLU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_size, self.num_classes),
+            )
+
+            self.val_accuracy = Accuracy()
+            self.test_accuracy = Accuracy()
+
+        def forward(self, x):
+            x = self.model(x)
+            return F.log_softmax(x, dim=1)
+
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            preds = torch.argmax(logits, dim=1)
+            self.val_accuracy.update(preds, y)
+
+            # Calling self.log will surface up scalars for you in TensorBoard
+            self.log("val_loss", loss, prog_bar=True)
+            self.log("val_acc", self.val_accuracy, prog_bar=True)
+
+        def test_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            preds = torch.argmax(logits, dim=1)
+            self.test_accuracy.update(preds, y)
+
+            # Calling self.log will surface up scalars for you in TensorBoard
+            self.log("test_loss", loss, prog_bar=True)
+            self.log("test_acc", self.test_accuracy, prog_bar=True)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+            return optimizer
+
+        ####################
+        # DATA RELATED HOOKS
+        ####################
+
+        def prepare_data(self):
+            # download
+            print("Downloading MNIST dataset...")
+
+            if "{{.StorageBucketDefaultEndpointExists}}" == "true" and "{{.StorageBucketDefaultEndpoint}}" != "":
+                print("Using storage bucket to download datasets...")
+                dataset_dir = os.path.join(self.data_dir, "MNIST/raw")
+                endpoint = "{{.StorageBucketDefaultEndpoint}}"
+                access_key = "{{.StorageBucketAccessKeyId}}"
+                secret_key = "{{.StorageBucketSecretKey}}"
+                bucket_name = "{{.StorageBucketName}}"
+
+                # Strip the scheme prefix, if present, from the storage bucket endpoint URL
+                secure = True
+                if endpoint.startswith("https://"):
+                    endpoint = endpoint[len("https://"):]
+                elif endpoint.startswith("http://"):
+                    endpoint = endpoint[len("http://"):]
+                    secure = False
+
+                client = Minio(
+                    endpoint,
+                    access_key=access_key,
+                    secret_key=secret_key,
+                    cert_check=False,
+                    secure=secure,
+                )
+
+                if not os.path.exists(dataset_dir):
+                    os.makedirs(dataset_dir)
+                else:
+                    print(f"Directory '{dataset_dir}' already exists")
+
+                # To download datasets from a specific directory of the storage bucket,
+                # pass the directory name as the prefix
+                prefix = "{{.StorageBucketMnistDir}}"
+                # Download all files under the prefix folder of the storage bucket recursively
+                for item in client.list_objects(bucket_name, prefix=prefix, recursive=True):
+                    file_name = item.object_name[len(prefix) + 1:]
+                    dataset_file_path = os.path.join(dataset_dir, file_name)
+                    print(dataset_file_path)
+                    if not os.path.exists(dataset_file_path):
+                        client.fget_object(bucket_name, item.object_name, dataset_file_path)
+                    else:
+                        print(f"File-path '{dataset_file_path}' already exists")
+                    # Unzip the .gz archive next to the download
+                    with gzip.open(dataset_file_path, "rb") as f_in:
+                        with open(os.path.splitext(dataset_file_path)[0], "wb") as f_out:
+                            shutil.copyfileobj(f_in, f_out)
+                    # Delete the archive
+                    os.remove(dataset_file_path)
+                download_datasets = False
+
+            else:
+                print("Using default MNIST mirror reference to download datasets...")
+                download_datasets = True
+
+            MNIST(self.data_dir, train=True, download=download_datasets)
+            MNIST(self.data_dir, train=False, download=download_datasets)
+
+        def setup(self, stage=None):
+            # Assign train/val datasets for use in dataloaders
+            if stage == "fit" or stage is None:
+                mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+
+            # Assign test dataset for use in dataloader(s)
+            if stage == "test" or stage is None:
+                self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
+
+        def train_dataloader(self):
+            return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000))
+
+        def val_dataloader(self):
+            return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
+
+        def test_dataloader(self):
+            return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
+
+    # Init DataLoader from MNIST Dataset
+    model = LitMNIST(data_dir=local_mnist_path)
+
+    print("GROUP: ", int(os.environ.get("GROUP_WORLD_SIZE", 1)))
+    print("LOCAL: ", int(os.environ.get("LOCAL_WORLD_SIZE", 1)))
+
+    # Initialize a trainer
+    trainer = Trainer(
+        accelerator="gpu" if torch.cuda.is_available() else "cpu",  # an accelerator has to be specified; pick GPU when available
+        # devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
+        max_epochs=3,
+        callbacks=[TQDMProgressBar(refresh_rate=20)],
+        num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
+        devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
+        replace_sampler_ddp=False,
+        strategy="ddp",
+    )
+
+    # Train the model ⚡
+    trainer.fit(model)
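Note that `train_func_3` only contacts the bucket when the rendered `{{.StorageBucketDefaultEndpointExists}}` placeholder equals the string "true"; in an unrendered copy of the script that comparison is false, so it falls back to the public torchvision MNIST mirror. A minimal single-process smoke test under that assumption; the environment values are hypothetical defaults, not part of the commit:

# Hypothetical single-process smoke test for train_func_3 (not part of this
# commit). With the Go-template placeholders left unrendered, the
# "{{.StorageBucketDefaultEndpointExists}}" == "true" check is false and
# prepare_data() falls back to the default torchvision MNIST download.
import os

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # DDP rendezvous address
os.environ.setdefault("MASTER_PORT", "29500")      # DDP rendezvous port
os.environ.setdefault("GROUP_WORLD_SIZE", "1")     # single node
os.environ.setdefault("LOCAL_WORLD_SIZE", "1")     # single process on that node

from kfto_sdk_mnist import train_func_3

train_func_3()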

tests/kfto/resources/mnist_kfto.ipynb

Lines changed: 2 additions & 3 deletions
@@ -9,7 +9,7 @@
 },
 "outputs": [],
 "source": [
-    "from kfto_sdk_mnist import train_func_2\n",
+    "from kfto_sdk_mnist import train_func_2, train_func_3\n",
     "from kubeflow.training import TrainingClient\n",
     "from kubernetes import client\n",
     "import time"
@@ -24,7 +24,6 @@
 "source": [
     "# parameters\n",
     "num_gpus = \"${num_gpus}\"\n",
-    "train_function = \"${train_function}\"\n",
     "openshift_api_url = \"${api_url}\"\n",
     "namespace = \"${namespace}\"\n",
     "token = \"${token}\"\n",
@@ -66,7 +65,7 @@
 "client.create_job(\n",
 "    name=\"pytorch-ddp\",\n",
 "    namespace=namespace,\n",
-"    train_func=train_function,\n",
+"    train_func=train_func_3,\n",
 "    num_workers=2,\n",
 "    resources_per_worker={\"gpu\": num_gpus},\n",
 ")"

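Pieced together, the notebook now imports `train_func_3` directly and hands it to the Training Operator SDK instead of going through the removed `${train_function}` parameter. The cells that build the client from the `${api_url}`/`${token}` parameters are outside this diff, so the configuration below is an assumption about that wiring; only the `create_job` arguments are confirmed by the commit, and all concrete values are hypothetical:

# Hypothetical end-to-end sketch of the notebook flow; the client setup is
# assumed, only the create_job call appears in this diff.
from kfto_sdk_mnist import train_func_3
from kubeflow.training import TrainingClient
from kubernetes import client as k8s_client

openshift_api_url = "https://api.cluster.example.com:6443"  # "${api_url}" (hypothetical)
namespace = "kfto-test"                                     # "${namespace}" (hypothetical)
token = "<token>"                                           # "${token}" (hypothetical)
num_gpus = 1                                                # "${num_gpus}" (hypothetical)

configuration = k8s_client.Configuration()
configuration.host = openshift_api_url
configuration.api_key = {"authorization": f"Bearer {token}"}

training_client = TrainingClient(client_configuration=configuration)
training_client.create_job(
    name="pytorch-ddp",
    namespace=namespace,
    train_func=train_func_3,                 # the training function added by this commit
    num_workers=2,
    resources_per_worker={"gpu": num_gpus},
)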