Skip to content

Commit c16fd7f

Browse files
template bucket info for MNIST dataset
Signed-off-by: Kevin <[email protected]>
1 parent 17c1e64 commit c16fd7f

File tree

2 files changed

+78
-24
lines changed

2 files changed

+78
-24
lines changed

tests/kfto/resources/kfto_sdk_mnist.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ def train_func():
66
from torchvision import datasets, transforms
77
import torch.distributed as dist
88
from pathlib import Path
9+
from minio import Minio
10+
import shutil
11+
import gzip
12+
913

1014
# [1] Setup PyTorch DDP. Distributed environment will be set automatically by Training Operator.
1115
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
@@ -45,13 +49,63 @@ def forward(self, x):
4549
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
4650

4751
# [4] Setup FashionMNIST dataloader and distribute data across PyTorchJob workers.
48-
Path(f"./data{local_rank}").mkdir(exist_ok=True)
49-
dataset = datasets.FashionMNIST(
50-
f"./data{local_rank}",
51-
download=True,
52-
train=True,
53-
transform=transforms.Compose([transforms.ToTensor()]),
54-
)
52+
dataset_path = "./data"
53+
dataset_dir = os.path.join(dataset_path, "MNIST/raw")
54+
with_aws = "{{.StorageBucketNameExists}}"
55+
endpoint = "{{.StorageBucketDefaultEndpoint}}"
56+
access_key = "{{.StorageBucketAccessKeyId}}"
57+
secret_key = "{{.StorageBucketSecretKey}}"
58+
bucket_name = "{{.StorageBucketName}}"
59+
prefix = "{{.StorageBucketMnistDir}}"
60+
if with_aws != "true":
61+
client = Minio(
62+
endpoint,
63+
access_key=access_key,
64+
secret_key=secret_key,
65+
cert_check=False,
66+
secure=False, #TODO
67+
)
68+
69+
if not os.path.exists(dataset_dir):
70+
os.makedirs(dataset_dir)
71+
72+
for item in client.list_objects(
73+
bucket_name, prefix=prefix, recursive=True
74+
):
75+
file_name=item.object_name[len(prefix)+1:]
76+
dataset_file_path = os.path.join(dataset_dir, file_name)
77+
print(f"Downloading dataset file {file_name} to {dataset_file_path}..")
78+
if not os.path.exists(dataset_file_path):
79+
client.fget_object(
80+
bucket_name, item.object_name, dataset_file_path
81+
)
82+
# Unzip files --
83+
## Sample zipfilepath : ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
84+
with gzip.open(dataset_file_path, "rb") as f_in:
85+
filename=file_name.split(".")[0] #-> t10k-images-idx3-ubyte
86+
file_path=("/".join(dataset_file_path.split("/")[:-1])) #->../data/MNIST/raw
87+
full_file_path=os.path.join(file_path,filename) #->../data/MNIST/raw/t10k-images-idx3-ubyte
88+
print(f"Extracting {dataset_file_path} to {file_path}..")
89+
90+
with open(full_file_path, "wb") as f_out:
91+
shutil.copyfileobj(f_in, f_out)
92+
print(f"Dataset file downloaded : {full_file_path}\n")
93+
# delete zip file
94+
os.remove(dataset_file_path)
95+
96+
dataset = datasets.MNIST(
97+
dataset_path,
98+
train=True,
99+
download=False,
100+
transform=transforms.Compose([transforms.ToTensor()]),
101+
)
102+
else:
103+
dataset = datasets.MNIST(
104+
dataset_path,
105+
train=True,
106+
download=True,
107+
transform=transforms.Compose([transforms.ToTensor()]),
108+
)
55109
train_loader = torch.utils.data.DataLoader(
56110
dataset=dataset,
57111
batch_size=128,

tests/kfto/resources/mnist_kfto.ipynb

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
5+
"execution_count": 6,
66
"id": "b55bc3ea-4ce3-49bf-bb1f-e209de8ca47a",
77
"metadata": {
88
"tags": []
@@ -19,7 +19,7 @@
1919
},
2020
{
2121
"cell_type": "code",
22-
"execution_count": 2,
22+
"execution_count": 7,
2323
"id": "72dd1751",
2424
"metadata": {},
2525
"outputs": [],
@@ -33,14 +33,15 @@
3333
},
3434
{
3535
"cell_type": "code",
36-
"execution_count": null,
36+
"execution_count": 8,
3737
"id": "4ca70b20",
3838
"metadata": {},
3939
"outputs": [],
4040
"source": [
41-
"config = c.Configuration(host=openshift_api_url, api_key=token)\n",
42-
"config.verify_ssl = False\n",
43-
"tc = TrainingClient()"
41+
"# config = c.Configuration(host=openshift_api_url, api_key=token)\n",
42+
"# config.verify_ssl = False\n",
43+
"tc = TrainingClient()\n",
44+
"# config.api_key_prefix"
4445
]
4546
},
4647
{
@@ -50,17 +51,16 @@
5051
"metadata": {},
5152
"outputs": [],
5253
"source": [
53-
"try:\n",
54-
" tc.create_job(\n",
55-
" name=\"pytorch-ddp\",\n",
56-
" namespace=namespace,\n",
57-
" train_func=train_func,\n",
58-
" num_workers=2,\n",
59-
" resources_per_worker={\"gpu\": num_gpus},\n",
60-
" base_image=\"quay.io/kpostlet/torch-train:withvision\",\n",
61-
" )\n",
62-
"except:\n",
63-
" pass"
54+
"tc.create_job(\n",
55+
" name=\"pytorch-ddp\",\n",
56+
" namespace=namespace,\n",
57+
" train_func=train_func,\n",
58+
" num_workers=2,\n",
59+
" resources_per_worker={\"gpu\": num_gpus},\n",
60+
" base_image=\"quay.io/kpostlet/torch-train:with-minivision\",\n",
61+
" # packages_to_install=[\"torchvision==0.19.0\", \"--target=/tmp/lib\"],\n",
62+
" # env_vars={\"PYTHONPATH\": \"/tmp/lib:$PYTHONPATH\", \"NCCL_DEBUG\": \"INFO\", \"TORCH_DISTRIBUTED_DEBUG\": \"DETAIL\"}\n",
63+
")"
6464
]
6565
},
6666
{

0 commit comments

Comments
 (0)