
Commit 39f5ccd

address comments and make training CPU only
1 parent dcb6265 · commit 39f5ccd

3 files changed, 10 insertions(+), 13 deletions(-)

tests/trainer/kubeflow_sdk_test.go

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ import (
 	sdktests "github.com/opendatahub-io/distributed-workloads/tests/trainer/sdk_tests"
 )
 
-func TestKubeflowSDK_Sanity(t *testing.T) {
+func TestKubeflowSdkSanity(t *testing.T) {
 	Tags(t, Sanity)
 	sdktests.RunFashionMnistCpuDistributedTraining(t)
 }

tests/trainer/resources/mnist.ipynb

Lines changed: 7 additions & 9 deletions

@@ -44,12 +44,13 @@
 " x = self.fc2(x)\n",
 " return F.log_softmax(x, dim=1)\n",
 "\n",
-" # Use NCCL if a GPU is available, otherwise use Gloo as communication backend.\n",
-" device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n",
-" print(f\"Using Device: {device}, Backend: {backend}\")\n",
+" # Force CPU-only for this test to avoid accidental NCCL/GPU usage\n",
+" backend = \"gloo\"\n",
+" device = torch.device(\"cpu\")\n",
+" print(f\"Using Device: cpu, Backend: {backend}\")\n",
 "\n",
 " # Setup PyTorch distributed.\n",
-" local_rank = int(os.getenv(\"PET_NODE_RANK\", 0))\n",
+" local_rank = int(os.getenv(\"LOCAL_RANK\") or os.getenv(\"PET_NODE_RANK\") or 0)\n",
 " dist.init_process_group(backend=backend)\n",
 " print(\n",
 " \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n",
@@ -60,19 +61,16 @@
 " )\n",
 "\n",
 " # Create the model and load it into the device.\n",
-" device = torch.device(f\"{device}:{local_rank}\")\n",
 " model = nn.parallel.DistributedDataParallel(Net().to(device))\n",
 " optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n",
 "\n",
-" \n",
 " # Prefer shared PVC if present; else fallback to internet download (rank 0 only)\n",
 " from urllib.parse import urlparse\n",
 " import gzip, shutil\n",
-" \n",
+"\n",
 " pvc_root = \"/mnt/shared\"\n",
 " pvc_raw = os.path.join(pvc_root, \"FashionMNIST\", \"raw\")\n",
 "\n",
-"\n",
 " use_pvc = os.path.isdir(pvc_raw) and any(os.scandir(pvc_raw))\n",
 "\n",
 " if not use_pvc:\n",
@@ -126,7 +124,7 @@
 "\n",
 " # Iterate over mini-batches from the training set\n",
 " for batch_idx, (inputs, labels) in enumerate(train_loader):\n",
-" # Copy the data to the GPU device if available\n",
+" # Move the data to the selected device\n",
 " inputs, labels = inputs.to(device), labels.to(device)\n",
 " # Forward pass\n",
 " outputs = model(inputs)\n",

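For reference, below is a minimal standalone sketch of the CPU-only distributed setup this notebook change converges on. It is not the notebook itself: it assumes PyTorch is installed and the script is launched with torchrun (for example, torchrun --nproc_per_node=2 cpu_ddp_sketch.py), and the file name, the tiny nn.Linear model, and the random tensors are illustrative stand-ins for the FashionMNIST pieces.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


def main():
    # Force CPU-only training: Gloo backend, no CUDA/NCCL involvement.
    backend = "gloo"
    device = torch.device("cpu")

    # torchrun exports LOCAL_RANK; the notebook additionally falls back to
    # PET_NODE_RANK and finally to 0, so the same chain is mirrored here.
    local_rank = int(os.getenv("LOCAL_RANK") or os.getenv("PET_NODE_RANK") or 0)

    dist.init_process_group(backend=backend)
    print(
        f"WORLD_SIZE={dist.get_world_size()} RANK={dist.get_rank()} "
        f"LOCAL_RANK={local_rank} backend={backend} device={device}"
    )

    # On CPU there is no per-rank device index, which is why the diff drops the
    # torch.device(f"{device}:{local_rank}") line; the model is wrapped as-is.
    model = nn.parallel.DistributedDataParallel(nn.Linear(10, 2).to(device))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # One dummy optimization step to show the loop shape; the notebook iterates
    # over FashionMNIST mini-batches from a DataLoader instead.
    inputs = torch.randn(8, 10, device=device)
    labels = torch.randint(0, 2, (8,), device=device)
    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()
    print(f"rank {dist.get_rank()}: step done, loss={loss.item():.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Pinning the backend to gloo keeps the sanity test on the CPU path even on clusters without GPUs, where initializing an NCCL process group would fail outright.
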
tests/trainer/sdk_tests/fashion_mnist_tests.go

Lines changed: 2 additions & 3 deletions

@@ -109,8 +109,7 @@ func RunFashionMnistCpuDistributedTraining(t *testing.T) {
 	podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name)
 
 	// Poll logs to check if the notebook execution completed successfully
-	if err := trainerutils.PollNotebookLogsForStatus(test, namespace.Name, podName, containerName, support.TestTimeoutDouble); err != nil {
-		test.Expect(err).To(Succeed(), "Notebook execution reported FAILURE")
-	}
+	err = trainerutils.PollNotebookLogsForStatus(test, namespace.Name, podName, containerName, support.TestTimeoutDouble)
+	test.Expect(err).ShouldNot(HaveOccurred(), "Notebook execution reported FAILURE")
 
 }
