@@ -138,3 +138,220 @@ def forward(self, x):
                loss.item(),
            )
        )
+
+def train_func_3():
+    import os
+
+    import torch
+    import requests
+    from pytorch_lightning import LightningModule, Trainer
+    from pytorch_lightning.callbacks.progress import TQDMProgressBar
+    from torch import nn
+    from torch.nn import functional as F
+    from torch.utils.data import DataLoader, random_split, RandomSampler
+    from torchmetrics import Accuracy
+    from torchvision import transforms
+    from torchvision.datasets import MNIST
+    import gzip
+    import shutil
+    from minio import Minio
+
+
+    PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
+    BATCH_SIZE = 256 if torch.cuda.is_available() else 64
+
+    local_mnist_path = os.path.dirname(os.path.abspath(__file__))
+
+    print("prior to running the trainer")
+    print("MASTER_ADDR: is ", os.getenv("MASTER_ADDR"))
+    print("MASTER_PORT: is ", os.getenv("MASTER_PORT"))
+
+
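+    # NOTE: the {{.*}} tokens below are Go template placeholders; they are
+    # substituted with concrete values before this script is executed.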
+    STORAGE_BUCKET_EXISTS = "{{.StorageBucketDefaultEndpointExists}}"
+    print("STORAGE_BUCKET_EXISTS: ", STORAGE_BUCKET_EXISTS)
+    print(f"{'Storage_Bucket_Default_Endpoint : is {{.StorageBucketDefaultEndpoint}}' if '{{.StorageBucketDefaultEndpointExists}}' == 'true' else ''}")
+    print(f"{'Storage_Bucket_Name : is {{.StorageBucketName}}' if '{{.StorageBucketNameExists}}' == 'true' else ''}")
+    print(f"{'Storage_Bucket_Mnist_Directory : is {{.StorageBucketMnistDir}}' if '{{.StorageBucketMnistDirExists}}' == 'true' else ''}")
+
+    class LitMNIST(LightningModule):
+        def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
+            super().__init__()
+
+            # Set our init args as class attributes
+            self.data_dir = data_dir
+            self.hidden_size = hidden_size
+            self.learning_rate = learning_rate
+
+            # Hardcode some dataset specific attributes
+            self.num_classes = 10
+            self.dims = (1, 28, 28)
+            channels, width, height = self.dims
+            self.transform = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize((0.1307,), (0.3081,)),
+                ]
+            )
+
+            # Define PyTorch model
+            self.model = nn.Sequential(
+                nn.Flatten(),
+                nn.Linear(channels * width * height, hidden_size),
+                nn.ReLU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_size, hidden_size),
+                nn.ReLU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_size, self.num_classes),
+            )
+
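+            # NOTE: bare Accuracy() is the torchmetrics < 0.11 API; newer
+            # releases require Accuracy(task="multiclass", num_classes=10).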
+            self.val_accuracy = Accuracy()
+            self.test_accuracy = Accuracy()
+
+        def forward(self, x):
+            x = self.model(x)
+            return F.log_softmax(x, dim=1)
+
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            preds = torch.argmax(logits, dim=1)
+            self.val_accuracy.update(preds, y)
+
+            # Calling self.log will surface up scalars for you in TensorBoard
+            self.log("val_loss", loss, prog_bar=True)
+            self.log("val_acc", self.val_accuracy, prog_bar=True)
+
+        def test_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            preds = torch.argmax(logits, dim=1)
+            self.test_accuracy.update(preds, y)
+
+            # Calling self.log will surface up scalars for you in TensorBoard
+            self.log("test_loss", loss, prog_bar=True)
+            self.log("test_acc", self.test_accuracy, prog_bar=True)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+            return optimizer
+
+        ####################
+        # DATA RELATED HOOKS
+        ####################
+
+        def prepare_data(self):
+            # download
+            print("Downloading MNIST dataset...")
+
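+            # If a storage bucket endpoint was templated in, pull the raw MNIST
+            # files from the bucket; otherwise fall back to the default mirrors.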
+            if "{{.StorageBucketDefaultEndpointExists}}" == "true" and "{{.StorageBucketDefaultEndpoint}}" != "":
+                print("Using storage bucket to download datasets...")
+                dataset_dir = os.path.join(self.data_dir, "MNIST/raw")
+                endpoint = "{{.StorageBucketDefaultEndpoint}}"
+                access_key = "{{.StorageBucketAccessKeyId}}"
+                secret_key = "{{.StorageBucketSecretKey}}"
+                bucket_name = "{{.StorageBucketName}}"
+
+                # remove prefix if specified in storage bucket endpoint url
+                secure = True
+                if endpoint.startswith("https://"):
+                    endpoint = endpoint[len("https://"):]
+                elif endpoint.startswith("http://"):
+                    endpoint = endpoint[len("http://"):]
+                    secure = False
+
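+                # NOTE: cert_check=False skips TLS certificate verification,
+                # which tolerates self-signed certificates on the endpoint.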
+                client = Minio(
+                    endpoint,
+                    access_key=access_key,
+                    secret_key=secret_key,
+                    cert_check=False,
+                    secure=secure,
+                )
+
+                if not os.path.exists(dataset_dir):
+                    os.makedirs(dataset_dir)
+                else:
+                    print(f"Directory '{dataset_dir}' already exists")
+
+                # To download datasets from storage bucket's specific directory, use prefix to provide directory name
+                prefix = "{{.StorageBucketMnistDir}}"
+                # download all files from prefix folder of storage bucket recursively
+                for item in client.list_objects(
+                    bucket_name, prefix=prefix, recursive=True
+                ):
+                    file_name = item.object_name[len(prefix) + 1:]
+                    dataset_file_path = os.path.join(dataset_dir, file_name)
+                    print(dataset_file_path)
+                    if not os.path.exists(dataset_file_path):
+                        client.fget_object(
+                            bucket_name, item.object_name, dataset_file_path
+                        )
+                    else:
+                        print(f"File-path '{dataset_file_path}' already exists")
+                    # Unzip files
+                    with gzip.open(dataset_file_path, "rb") as f_in:
+                        with open(dataset_file_path.split(".")[:-1][0], "wb") as f_out:
+                            shutil.copyfileobj(f_in, f_out)
+                    # delete zip file
+                    os.remove(dataset_file_path)
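+                # raw files are now present locally, so skip the torchvision download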
+                download_datasets = False
+
+            else:
+                print("Using default MNIST mirror reference to download datasets...")
+                download_datasets = True
+
+            MNIST(self.data_dir, train=True, download=download_datasets)
+            MNIST(self.data_dir, train=False, download=download_datasets)
+
+        def setup(self, stage=None):
+
+            # Assign train/val datasets for use in dataloaders
+            if stage == "fit" or stage is None:
+                mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+
+            # Assign test dataset for use in dataloader(s)
+            if stage == "test" or stage is None:
+                self.mnist_test = MNIST(
+                    self.data_dir, train=False, transform=self.transform
+                )
+
+        def train_dataloader(self):
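+            # RandomSampler with num_samples=1000 draws only 1000 examples per
+            # epoch, which keeps each training epoch short.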
+            return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000))
+
+        def val_dataloader(self):
+            return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
+
+        def test_dataloader(self):
+            return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
+
+
+    # Init our model, pointing at the local MNIST data directory
+
+    model = LitMNIST(data_dir=local_mnist_path)
+
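+    # NOTE: GROUP_WORLD_SIZE and LOCAL_WORLD_SIZE are typically set by the
+    # distributed launcher; they map to num_nodes and devices below.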
+    print("GROUP: ", int(os.environ.get("GROUP_WORLD_SIZE", 1)))
+    print("LOCAL: ", int(os.environ.get("LOCAL_WORLD_SIZE", 1)))
+
+    # Initialize a trainer
+    trainer = Trainer(
+        accelerator="has to be specified",  # placeholder: replace with "cpu", "gpu", or "auto"
+        # devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
+        max_epochs=3,
+        callbacks=[TQDMProgressBar(refresh_rate=20)],
+        num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
+        devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
+        replace_sampler_ddp=False,  # pytorch_lightning < 2.0 API
+        strategy="ddp",
+    )
+
+    # Train the model ⚡
+    trainer.fit(model)