Skip to content

Commit d34ad1e

Browse files
test: initial implementation of SDK e2e
1 parent d0f081a commit d34ad1e

File tree

6 files changed

+560
-0
lines changed

6 files changed

+560
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
bin/*
2+
.vscode/*

tests/trainer/kubeflow_sdk_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
Copyright 2025.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package trainer
18+
19+
import (
20+
"testing"
21+
22+
. "github.com/opendatahub-io/distributed-workloads/tests/common"
23+
sdktests "github.com/opendatahub-io/distributed-workloads/tests/trainer/sdk_tests"
24+
)
25+
26+
func TestKubeflowSDK_Sanity(t *testing.T) {
27+
Tags(t, Sanity)
28+
sdktests.RunFashionMnistCpuDistributedTraining(t)
29+
// ADD MORE SANITY TESTS HERE
30+
}
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {
7+
"execution": {
8+
"iopub.execute_input": "2025-09-03T13:19:46.917723Z",
9+
"iopub.status.busy": "2025-09-03T13:19:46.917308Z",
10+
"iopub.status.idle": "2025-09-03T13:19:46.935181Z",
11+
"shell.execute_reply": "2025-09-03T13:19:46.934697Z",
12+
"shell.execute_reply.started": "2025-09-03T13:19:46.917698Z"
13+
}
14+
},
15+
"outputs": [],
16+
"source": [
17+
"def train_fashion_mnist():\n",
18+
" import os\n",
19+
"\n",
20+
" import torch\n",
21+
" import torch.distributed as dist\n",
22+
" import torch.nn.functional as F\n",
23+
" from torch import nn\n",
24+
" from torch.utils.data import DataLoader, DistributedSampler\n",
25+
" from torchvision import datasets, transforms\n",
26+
"\n",
27+
" # Define the PyTorch CNN model to be trained\n",
28+
" class Net(nn.Module):\n",
29+
" def __init__(self):\n",
30+
" super(Net, self).__init__()\n",
31+
" self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
32+
" self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
33+
" self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
34+
" self.fc2 = nn.Linear(500, 10)\n",
35+
"\n",
36+
" def forward(self, x):\n",
37+
" x = F.relu(self.conv1(x))\n",
38+
" x = F.max_pool2d(x, 2, 2)\n",
39+
" x = F.relu(self.conv2(x))\n",
40+
" x = F.max_pool2d(x, 2, 2)\n",
41+
" x = x.view(-1, 4 * 4 * 50)\n",
42+
" x = F.relu(self.fc1(x))\n",
43+
" x = self.fc2(x)\n",
44+
" return F.log_softmax(x, dim=1)\n",
45+
"\n",
46+
" # Use NCCL if a GPU is available, otherwise use Gloo as communication backend.\n",
47+
" device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n",
48+
" print(f\"Using Device: {device}, Backend: {backend}\")\n",
49+
"\n",
50+
" # Setup PyTorch distributed.\n",
51+
" local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n",
52+
" dist.init_process_group(backend=backend)\n",
53+
" print(\n",
54+
" \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n",
55+
" dist.get_world_size(),\n",
56+
" dist.get_rank(),\n",
57+
" local_rank,\n",
58+
" )\n",
59+
" )\n",
60+
"\n",
61+
" # Create the model and load it into the device.\n",
62+
" device = torch.device(f\"{device}:{local_rank}\")\n",
63+
" model = nn.parallel.DistributedDataParallel(Net().to(device))\n",
64+
" optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n",
65+
"\n",
66+
" \n",
67+
" # Download FashionMNIST dataset only on local_rank=0 process.\n",
68+
" if local_rank == 0:\n",
69+
" dataset = datasets.FashionMNIST(\n",
70+
" \"./data\",\n",
71+
" train=True,\n",
72+
" download=True,\n",
73+
" transform=transforms.Compose([transforms.ToTensor()]),\n",
74+
" )\n",
75+
" dist.barrier()\n",
76+
" dataset = datasets.FashionMNIST(\n",
77+
" \"./data\",\n",
78+
" train=True,\n",
79+
" download=False,\n",
80+
" transform=transforms.Compose([transforms.ToTensor()]),\n",
81+
" )\n",
82+
"\n",
83+
"\n",
84+
"    # Shard the dataset across workers.\n",
85+
" train_loader = DataLoader(\n",
86+
" dataset,\n",
87+
" batch_size=100,\n",
88+
" sampler=DistributedSampler(dataset)\n",
89+
" )\n",
90+
"\n",
91+
" # TODO(astefanutti): add parameters to the training function\n",
92+
" dist.barrier()\n",
93+
" for epoch in range(1, 3):\n",
94+
" model.train()\n",
95+
"\n",
96+
" # Iterate over mini-batches from the training set\n",
97+
" for batch_idx, (inputs, labels) in enumerate(train_loader):\n",
98+
" # Copy the data to the GPU device if available\n",
99+
" inputs, labels = inputs.to(device), labels.to(device)\n",
100+
" # Forward pass\n",
101+
" outputs = model(inputs)\n",
102+
" loss = F.nll_loss(outputs, labels)\n",
103+
" # Backward pass\n",
104+
" optimizer.zero_grad()\n",
105+
" loss.backward()\n",
106+
" optimizer.step()\n",
107+
"\n",
108+
" if batch_idx % 10 == 0 and dist.get_rank() == 0:\n",
109+
" print(\n",
110+
" \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n",
111+
" epoch,\n",
112+
" batch_idx * len(inputs),\n",
113+
" len(train_loader.dataset),\n",
114+
" 100.0 * batch_idx / len(train_loader),\n",
115+
" loss.item(),\n",
116+
" )\n",
117+
" )\n",
118+
"\n",
119+
" # Wait for the distributed training to complete\n",
120+
" dist.barrier()\n",
121+
" if dist.get_rank() == 0:\n",
122+
" print(\"Training is finished\")\n",
123+
"\n",
124+
" # Finally clean up PyTorch distributed\n",
125+
" dist.destroy_process_group()"
126+
]
127+
},
128+
{
129+
"cell_type": "code",
130+
"execution_count": null,
131+
"metadata": {
132+
"execution": {
133+
"iopub.execute_input": "2025-09-03T13:19:49.832393Z",
134+
"iopub.status.busy": "2025-09-03T13:19:49.832117Z",
135+
"iopub.status.idle": "2025-09-03T13:19:51.924613Z",
136+
"shell.execute_reply": "2025-09-03T13:19:51.924264Z",
137+
"shell.execute_reply.started": "2025-09-03T13:19:49.832371Z"
138+
},
139+
"pycharm": {
140+
"name": "#%%\n"
141+
}
142+
},
143+
"outputs": [],
144+
"source": [
145+
"from kubeflow.trainer import CustomTrainer, TrainerClient\n",
146+
"\n",
147+
"client = TrainerClient()\n"
148+
]
149+
},
150+
{
151+
"cell_type": "code",
152+
"execution_count": null,
153+
"metadata": {},
154+
"outputs": [],
155+
"source": [
156+
"for runtime in client.list_runtimes():\n",
157+
" print(runtime)\n",
158+
" if runtime.name == \"universal\": # Update to actual universal image runtime once available\n",
159+
" torch_runtime = runtime"
160+
]
161+
},
162+
{
163+
"cell_type": "code",
164+
"execution_count": null,
165+
"metadata": {
166+
"execution": {
167+
"iopub.execute_input": "2025-09-03T13:19:56.525591Z",
168+
"iopub.status.busy": "2025-09-03T13:19:56.524936Z",
169+
"iopub.status.idle": "2025-09-03T13:19:56.721404Z",
170+
"shell.execute_reply": "2025-09-03T13:19:56.720565Z",
171+
"shell.execute_reply.started": "2025-09-03T13:19:56.525536Z"
172+
}
173+
},
174+
"outputs": [],
175+
"source": [
176+
"job_name = client.train(\n",
177+
" trainer=CustomTrainer(\n",
178+
" func=train_fashion_mnist,\n",
179+
" num_nodes=2,\n",
180+
" resources_per_node={\n",
181+
" \"cpu\": 2,\n",
182+
" \"memory\": \"8Gi\",\n",
183+
" },\n",
184+
" packages_to_install=[\"torchvision\"],\n",
185+
" ),\n",
186+
" runtime=torch_runtime,\n",
187+
")"
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": null,
193+
"metadata": {
194+
"execution": {
195+
"iopub.execute_input": "2025-09-03T13:20:01.378158Z",
196+
"iopub.status.busy": "2025-09-03T13:20:01.377707Z",
197+
"iopub.status.idle": "2025-09-03T13:20:12.713960Z",
198+
"shell.execute_reply": "2025-09-03T13:20:12.713295Z",
199+
"shell.execute_reply.started": "2025-09-03T13:20:01.378130Z"
200+
}
201+
},
202+
"outputs": [],
203+
"source": [
204+
"# Wait for the running status.\n",
205+
"client.wait_for_job_status(name=job_name, status={\"Running\"})"
206+
]
207+
},
208+
{
209+
"cell_type": "code",
210+
"execution_count": null,
211+
"metadata": {
212+
"execution": {
213+
"iopub.execute_input": "2025-09-03T13:20:24.045774Z",
214+
"iopub.status.busy": "2025-09-03T13:20:24.045480Z",
215+
"iopub.status.idle": "2025-09-03T13:20:24.772877Z",
216+
"shell.execute_reply": "2025-09-03T13:20:24.772178Z",
217+
"shell.execute_reply.started": "2025-09-03T13:20:24.045755Z"
218+
}
219+
},
220+
"outputs": [],
221+
"source": [
222+
"for c in client.get_job(name=job_name).steps:\n",
223+
" print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")"
224+
]
225+
},
226+
{
227+
"cell_type": "code",
228+
"execution_count": null,
229+
"metadata": {
230+
"execution": {
231+
"iopub.execute_input": "2025-09-03T13:20:26.729486Z",
232+
"iopub.status.busy": "2025-09-03T13:20:26.728951Z",
233+
"iopub.status.idle": "2025-09-03T13:20:29.596510Z",
234+
"shell.execute_reply": "2025-09-03T13:20:29.594741Z",
235+
"shell.execute_reply.started": "2025-09-03T13:20:26.729446Z"
236+
}
237+
},
238+
"outputs": [],
239+
"source": [
240+
"for logline in client.get_job_logs(job_name, follow=True):\n",
241+
" print(logline)"
242+
]
243+
},
244+
{
245+
"cell_type": "code",
246+
"execution_count": null,
247+
"metadata": {},
248+
"outputs": [],
249+
"source": [
250+
"client.delete_job(job_name)"
251+
]
252+
}
253+
],
254+
"metadata": {
255+
"kernelspec": {
256+
"display_name": "Python 3 (ipykernel)",
257+
"language": "python",
258+
"name": "python3"
259+
},
260+
"language_info": {
261+
"codemirror_mode": {
262+
"name": "ipython",
263+
"version": 3
264+
},
265+
"file_extension": ".py",
266+
"mimetype": "text/x-python",
267+
"name": "python",
268+
"nbconvert_exporter": "python",
269+
"pygments_lexer": "ipython3",
270+
"version": "3.11.13"
271+
}
272+
},
273+
"nbformat": 4,
274+
"nbformat_minor": 4
275+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
Copyright 2025.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package sdk_tests
18+
19+
import (
20+
"fmt"
21+
"os"
22+
"testing"
23+
24+
. "github.com/onsi/gomega"
25+
26+
corev1 "k8s.io/api/core/v1"
27+
28+
common "github.com/opendatahub-io/distributed-workloads/tests/common"
29+
support "github.com/opendatahub-io/distributed-workloads/tests/common/support"
30+
trainerutils "github.com/opendatahub-io/distributed-workloads/tests/trainer/utils"
31+
)
32+
33+
// Notebook fixture used by the SDK tests in this package.
const (
	// notebookName is the notebook file name; it is also used as the key of
	// the ConfigMap entry that carries the notebook content.
	notebookName = "mnist.ipynb"

	// notebookPath is the relative path the notebook is read from.
	notebookPath = "resources/" + notebookName
)
37+
38+
// CPU Only - Distributed Training
39+
func RunFashionMnistCpuDistributedTraining(t *testing.T) {
40+
test := support.With(t)
41+
42+
// Create a new test namespace
43+
namespace := test.NewTestNamespace()
44+
45+
// Ensure pre-requisites to run the test are met
46+
trainerutils.EnsureTrainerClusterReady(t, test)
47+
48+
// Ensure Notebook SA and RBACs are set for this namespace
49+
trainerutils.EnsureNotebookRBAC(t, test, namespace.Name)
50+
51+
// RBACs setup
52+
userName := common.GetNotebookUserName(test)
53+
userToken := common.GetNotebookUserToken(test)
54+
support.CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin")
55+
56+
// Read notebook from directory
57+
localPath := notebookPath
58+
nb, err := os.ReadFile(localPath)
59+
test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("failed to read notebook: %s", localPath))
60+
61+
// Create ConfigMap with notebook
62+
cm := support.CreateConfigMap(test, namespace.Name, map[string][]byte{notebookName: nb})
63+
64+
// Build command
65+
marker := "/opt/app-root/src/notebook_completion_marker"
66+
shellCmd := trainerutils.BuildPapermillShellCmd(notebookName, marker, nil)
67+
command := []string{"/bin/sh", "-c", shellCmd}
68+
69+
// Create Notebook CR (with default 10Gi PVC)
70+
pvc := support.CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", support.AccessModes(corev1.ReadWriteOnce))
71+
common.CreateNotebook(test, namespace, userToken, command, cm.Name, notebookName, 0, pvc, common.ContainerSizeSmall)
72+
73+
// Cleanup
74+
defer func() {
75+
common.DeleteNotebook(test, namespace)
76+
test.Eventually(common.Notebooks(test, namespace), support.TestTimeoutLong).Should(HaveLen(0))
77+
}()
78+
79+
// Wait for the Notebook Pod and get pod/container names
80+
podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name)
81+
82+
// Poll marker file to check if the notebook execution completed successfully
83+
if err := trainerutils.PollNotebookCompletionMarker(test, namespace.Name, podName, containerName, marker, support.TestTimeoutDouble); err != nil {
84+
test.Expect(err).To(Succeed(), "Notebook execution reported FAILURE")
85+
}
86+
}

0 commit comments

Comments
 (0)