@@ -30,17 +30,26 @@ import (
30
30
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
31
31
)
32
32
33
- func TestPyTorchJobMnistCpu (t * testing.T ) {
33
+ func TestPyTorchJobMnistMultiNodeCpu (t * testing.T ) {
34
34
runKFTOPyTorchMnistJob (t , 0 , 2 , "" , GetCudaTrainingImage (), "resources/requirements.txt" )
35
35
}
36
- func TestPyTorchJobMnistWithCuda (t * testing.T ) {
36
+
37
+ func TestPyTorchJobMnistMultiNodeWithCuda (t * testing.T ) {
37
38
runKFTOPyTorchMnistJob (t , 1 , 1 , "nvidia.com/gpu" , GetCudaTrainingImage (), "resources/requirements.txt" )
38
39
}
39
40
40
- func TestPyTorchJobMnistWithROCm (t * testing.T ) {
41
+ func TestPyTorchJobMnistMultiNodeWithROCm (t * testing.T ) {
41
42
runKFTOPyTorchMnistJob (t , 1 , 1 , "amd.com/gpu" , GetROCmTrainingImage (), "resources/requirements-rocm.txt" )
42
43
}
43
44
45
+ func TestPyTorchJobMnistMultiNodeMultiGpuWithCuda (t * testing.T ) {
46
+ runKFTOPyTorchMnistJob (t , 2 , 1 , "nvidia.com/gpu" , GetCudaTrainingImage (), "resources/requirements.txt" )
47
+ }
48
+
49
+ func TestPyTorchJobMnistMultiNodeMultiGpuWithROCm (t * testing.T ) {
50
+ runKFTOPyTorchMnistJob (t , 2 , 1 , "amd.com/gpu" , GetROCmTrainingImage (), "resources/requirements-rocm.txt" )
51
+ }
52
+
44
53
func runKFTOPyTorchMnistJob (t * testing.T , numGpus int , workerReplicas int , gpuLabel string , image string , requirementsFile string ) {
45
54
test := With (t )
46
55
0 commit comments