@@ -36,30 +36,50 @@ import (
3636
3737func TestPyTorchDDPMultiNodeMultiCPUWithTorchCuda28 (t * testing.T ) {
3838 Tags (t , Tier1 , MultiNode (2 ))
39- runPyTorchDDPMultiNodeJob (t , CPU , GetTrainingCudaPyTorch28Image (), 2 , 2 )
39+ runPyTorchDDPMultiNodeJob (t , CPU , GetTrainingCudaPyTorch28Image (), "resources/requirements.txt" , 2 , 2 )
40+ }
41+
42+ func TestPyTorchDDPSingleNodeSingleGPUWithTorchCuda (t * testing.T ) {
43+ Tags (t , KftoCuda )
44+ runPyTorchDDPMultiNodeJob (t , NVIDIA , GetTrainingCudaPyTorch28Image (), "resources/requirements.txt" , 1 , 1 )
4045}
4146
4247func TestPyTorchDDPSingleNodeMultiGPUWithTorchCuda (t * testing.T ) {
43- Tags (t , KftoCuda , MultiNode (2 ))
44- runPyTorchDDPMultiNodeJob (t , NVIDIA , GetTrainingCudaPyTorch28Image (), 1 , 2 )
48+ Tags (t , KftoCuda )
49+ runPyTorchDDPMultiNodeJob (t , NVIDIA , GetTrainingCudaPyTorch28Image (), "resources/requirements.txt" , 1 , 2 )
50+ }
51+
52+ func TestPyTorchDDPMultiNodeSingleGPUWithTorchCuda (t * testing.T ) {
53+ Tags (t , KftoCuda )
54+ runPyTorchDDPMultiNodeJob (t , NVIDIA , GetTrainingCudaPyTorch28Image (), "resources/requirements.txt" , 2 , 1 )
4555}
4656
4757func TestPyTorchDDPMultiNodeMultiGPUWithTorchCuda (t * testing.T ) {
48- Tags (t , KftoCuda , MultiNode (2 ))
49- runPyTorchDDPMultiNodeJob (t , NVIDIA , GetTrainingCudaPyTorch28Image (), 2 , 2 )
58+ Tags (t , KftoCuda )
59+ runPyTorchDDPMultiNodeJob (t , NVIDIA , GetTrainingCudaPyTorch28Image (), "resources/requirements.txt" , 2 , 2 )
60+ }
61+
62+ func TestPyTorchDDPSingleNodeSingleGPUWithTorchRocm (t * testing.T ) {
63+ Tags (t , KftoRocm )
64+ runPyTorchDDPMultiNodeJob (t , AMD , GetTrainingRocmPyTorch28Image (), "resources/requirements-rocm.txt" , 1 , 1 )
5065}
5166
5267func TestPyTorchDDPSingleNodeMultiGPUWithTorchRocm (t * testing.T ) {
53- Tags (t , KftoRocm , MultiNode (2 ))
54- runPyTorchDDPMultiNodeJob (t , AMD , GetTrainingRocmPyTorch28Image (), 1 , 2 )
68+ Tags (t , KftoRocm )
69+ runPyTorchDDPMultiNodeJob (t , AMD , GetTrainingRocmPyTorch28Image (), "resources/requirements-rocm.txt" , 1 , 2 )
70+ }
71+
72+ func TestPyTorchDDPMultiNodeSingleGPUWithTorchRocm (t * testing.T ) {
73+ Tags (t , KftoRocm )
74+ runPyTorchDDPMultiNodeJob (t , AMD , GetTrainingRocmPyTorch28Image (), "resources/requirements-rocm.txt" , 2 , 1 )
5575}
5676
5777func TestPyTorchDDPMultiNodeMultiGPUWithTorchRocm (t * testing.T ) {
58- Tags (t , KftoRocm , MultiNode ( 2 ) )
59- runPyTorchDDPMultiNodeJob (t , AMD , GetTrainingRocmPyTorch28Image (), 2 , 2 )
78+ Tags (t , KftoRocm )
79+ runPyTorchDDPMultiNodeJob (t , AMD , GetTrainingRocmPyTorch28Image (), "resources/requirements-rocm.txt" , 2 , 2 )
6080}
6181
62- func runPyTorchDDPMultiNodeJob (t * testing.T , accelerator Accelerator , baseImage string , numNodes , numProcPerNode int32 ) {
82+ func runPyTorchDDPMultiNodeJob (t * testing.T , accelerator Accelerator , baseImage string , requirementsFile string , numNodes , numProcPerNode int32 ) {
6383 test := With (t )
6484
6585 // Create a namespace
@@ -72,21 +92,19 @@ func runPyTorchDDPMultiNodeJob(t *testing.T, accelerator Accelerator, baseImage
7292 // Create PVC
7393 pvc := CreatePersistentVolumeClaim (test , namespace , "2Gi" , AccessModes (corev1 .ReadWriteMany ), StorageClassName (storageClass .Name ))
7494
75- // Create ConfigMap
95+ // Create ConfigMap with training scripts and requirements
7696 files := map [string ][]byte {
7797 "fashion_mnist.py" : readFile (test , "resources/fashion_mnist.py" ),
7898 "download_fashion_mnist.py" : readFile (test , "resources/download_fashion_mnist.py" ),
79- "requirements.txt" : readFile (test , "resources/requirements.txt" ),
99+ "requirements.txt" : readFile (test , requirementsFile ),
80100 }
81101 config := CreateConfigMap (test , namespace , files )
82102
83103 // Create TrainingRuntime with dataset-initializer
84104 trainingRuntime := createFashionMNISTTrainingRuntime (test , namespace , config .Name , pvc .Name , baseImage , accelerator , numProcPerNode )
85- defer deleteTrainingRuntime (test , namespace , trainingRuntime .Name )
86105
87106 // Create TrainJob
88107 trainJob := createFashionMNISTTrainJob (test , namespace , trainingRuntime .Name , accelerator , numNodes , numProcPerNode )
89- defer deleteTrainJob (test , namespace , trainJob .Name )
90108
91109 // Verify JobSet creation
92110 test .T ().Logf ("Verifying JobSet creation with replicated jobs..." )
@@ -111,13 +129,6 @@ func runPyTorchDDPMultiNodeJob(t *testing.T, accelerator Accelerator, baseImage
111129func createFashionMNISTTrainingRuntime (test Test , namespace , configMapName , pvcName , baseImage string , accelerator Accelerator , numProcPerNode int32 ) * trainerv1alpha1.TrainingRuntime {
112130 test .T ().Helper ()
113131
114- var backend string
115- if accelerator .IsGpu () {
116- backend = "nccl"
117- } else {
118- backend = "gloo"
119- }
120-
121132 trainingRuntime := & trainerv1alpha1.TrainingRuntime {
122133 ObjectMeta : metav1.ObjectMeta {
123134 GenerateName : "test-fashion-mnist-runtime-" ,
@@ -161,28 +172,41 @@ func createFashionMNISTTrainingRuntime(test Test, namespace, configMapName, pvcN
161172 echo " Dataset Initializer "
162173 echo "=========================================="
163174
164- echo "Installing dependencies from requirements.txt to ${LIB_PATH}..."
175+ # Install to local temp directory first, then copy to PVC to avoid Azure Files SMB cross-device link issues
176+ LOCAL_LIB=/tmp/pip_packages
177+ mkdir -p ${LOCAL_LIB}
165178 mkdir -p ${LIB_PATH}
166179
167- while IFS= read -r line || [[ -n "$line" ]]; do
180+ echo "Installing dependencies from requirements.txt ..."
181+ # Extract --extra-index-url if present in requirements.txt file
182+ EXTRA_INDEX=""
183+ if grep -q "^--extra-index-url" /mnt/scripts/requirements.txt; then
184+ EXTRA_INDEX=$(grep "^--extra-index-url" /mnt/scripts/requirements.txt)
185+ echo "Using: $EXTRA_INDEX"
186+ fi
168187
169- [[ -z "$line" || "$line" =~ ^[[:space:]]*# ]] && continue
188+ # Parse requirements file and install packages
189+ while IFS= read -r line || [[ -n "$line" ]]; do
190+ # Skip empty lines, comments, and pip options (like --extra-index-url)
191+ [[ -z "$line" || "$line" =~ ^[[:space:]]*# || "$line" =~ ^-- ]] && continue
170192
171193 # Check if line has "# no-deps" marker
172194 if [[ "$line" =~ "# no-deps" ]]; then
173195 pkg=$(echo "$line" | sed 's/[[:space:]]*#.*//')
174- echo "Installing $pkg (--no-deps) ..."
175- pip install --no-cache-dir --no-deps "$pkg" --target=${LIB_PATH}
196+ echo "Installing $pkg without dependencies ..."
197+ pip install --no-cache-dir --no-deps $EXTRA_INDEX "$pkg" --target=${LOCAL_LIB} --verbose
176198 else
177199 # Install with dependencies
178200 pkg=$(echo "$line" | sed 's/[[:space:]]*#.*//')
179- echo "Installing $pkg ( with deps) ..."
180- pip install --no-cache-dir "$pkg" --target=${LIB_PATH}
201+ echo "Installing $pkg with dependencies ..."
202+ pip install --no-cache-dir $EXTRA_INDEX "$pkg" --target=${LOCAL_LIB} --verbose
181203 fi
182204 done < /mnt/scripts/requirements.txt
183205
184206 echo ""
185- echo "Dependencies installed successfully ..."
207+ echo "Copying installed packages to ${LIB_PATH}..."
208+ cp -r ${LOCAL_LIB}/* ${LIB_PATH}/
209+ echo "Dependencies installed successfully!"
186210 ls -la ${LIB_PATH}/ | head -20
187211
188212 # Download dataset to shared volume
@@ -271,11 +295,10 @@ func createFashionMNISTTrainingRuntime(test Test, namespace, configMapName, pvcN
271295 ImagePullPolicy : corev1 .PullIfNotPresent ,
272296 Command : []string {"/bin/bash" , "-c" },
273297 Args : []string {
274- fmt . Sprintf ( `
298+ `
275299 set -e
276300
277301 echo "==================== Environment Info ===================="
278- echo "PyTorch Backend: %s"
279302 echo "Dataset Path: ${DATASET_PATH}"
280303 echo "==========================================================="
281304
@@ -309,7 +332,7 @@ func createFashionMNISTTrainingRuntime(test Test, namespace, configMapName, pvcN
309332
310333 echo ""
311334 echo "==================== Training Complete ===================="
312- ` , backend ),
335+ ` ,
313336 },
314337 Resources : buildResourceRequirements (accelerator , numProcPerNode ),
315338 VolumeMounts : []corev1.VolumeMount {
0 commit comments