@@ -6,6 +6,10 @@ def train_func():
     from torchvision import datasets, transforms
     import torch.distributed as dist
     from pathlib import Path
+    from minio import Minio
+    import shutil
+    import gzip
+

     # [1] Setup PyTorch DDP. Distributed environment will be set automatically by Training Operator.
     dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
@@ -45,13 +49,63 @@ def forward(self, x):
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

     # [4] Setup FashionMNIST dataloader and distribute data across PyTorchJob workers.
-    Path(f"./data{local_rank}").mkdir(exist_ok=True)
-    dataset = datasets.FashionMNIST(
-        f"./data{local_rank}",
-        download=True,
-        train=True,
-        transform=transforms.Compose([transforms.ToTensor()]),
-    )
+    dataset_path = "./data"
+    dataset_dir = os.path.join(dataset_path, "MNIST/raw")
+    with_aws = "{{.StorageBucketNameExists}}"
+    endpoint = "{{.StorageBucketDefaultEndpoint}}"
+    access_key = "{{.StorageBucketAccessKeyId}}"
+    secret_key = "{{.StorageBucketSecretKey}}"
+    bucket_name = "{{.StorageBucketName}}"
+    prefix = "{{.StorageBucketMnistDir}}"
+    if with_aws != "true":
+        client = Minio(
+            endpoint,
+            access_key=access_key,
+            secret_key=secret_key,
+            cert_check=False,
+            secure=False,  # TODO
+        )
+
+        if not os.path.exists(dataset_dir):
+            os.makedirs(dataset_dir)
+
+        for item in client.list_objects(
+            bucket_name, prefix=prefix, recursive=True
+        ):
+            file_name = item.object_name[len(prefix) + 1:]
+            dataset_file_path = os.path.join(dataset_dir, file_name)
+            print(f"Downloading dataset file {file_name} to {dataset_file_path}..")
+            if not os.path.exists(dataset_file_path):
+                client.fget_object(
+                    bucket_name, item.object_name, dataset_file_path
+                )
+                # Unzip the downloaded archive.
+                # Sample zip file path: ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
+                with gzip.open(dataset_file_path, "rb") as f_in:
+                    filename = file_name.split(".")[0]  # -> t10k-images-idx3-ubyte
+                    file_path = "/".join(dataset_file_path.split("/")[:-1])  # -> ./data/MNIST/raw
+                    full_file_path = os.path.join(file_path, filename)  # -> ./data/MNIST/raw/t10k-images-idx3-ubyte
+                    print(f"Extracting {dataset_file_path} to {file_path}..")
+
+                    with open(full_file_path, "wb") as f_out:
+                        shutil.copyfileobj(f_in, f_out)
+                print(f"Dataset file downloaded : {full_file_path}\n")
+                # Delete the zip file once it has been extracted.
+                os.remove(dataset_file_path)
+
+        dataset = datasets.MNIST(
+            dataset_path,
+            train=True,
+            download=False,
+            transform=transforms.Compose([transforms.ToTensor()]),
+        )
+    else:
+        dataset = datasets.MNIST(
+            dataset_path,
+            train=True,
+            download=True,
+            transform=transforms.Compose([transforms.ToTensor()]),
+        )

     train_loader = torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=128,
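
Note on the download loop above: it strips `prefix` plus one separator character from each object name, so it expects the gzipped MNIST raw files to live in the bucket under "<prefix>/<file>.gz". Below is a minimal sketch of seeding such a bucket with the standard minio-py API; the endpoint, credentials, bucket name and prefix are illustrative assumptions, not values from this change.

# Hypothetical seeding script (all concrete values are assumed) that uploads
# the gzipped MNIST raw files so the download loop above can find them
# under "<prefix>/<file>.gz".
import os
from minio import Minio

client = Minio(
    "minio-service:9000",   # assumed endpoint
    access_key="minio",     # assumed credentials
    secret_key="minio123",
    secure=False,
)

bucket_name = "mnist-datasets"      # assumed bucket name
prefix = "datasets/mnist"           # assumed prefix
local_raw_dir = "./data/MNIST/raw"  # e.g. produced by an earlier torchvision download

if not client.bucket_exists(bucket_name):
    client.make_bucket(bucket_name)

for file_name in os.listdir(local_raw_dir):
    if file_name.endswith(".gz"):
        client.fput_object(
            bucket_name,
            f"{prefix}/{file_name}",
            os.path.join(local_raw_dir, file_name),
        )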