@@ -30,17 +30,26 @@ import (
3030 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3131)
3232
33- func TestPyTorchJobMnistCpu (t * testing.T ) {
33+ func TestPyTorchJobMnistMultiNodeCpu (t * testing.T ) {
3434 runKFTOPyTorchMnistJob (t , 0 , 2 , "" , GetCudaTrainingImage (), "resources/requirements.txt" )
3535}
36- func TestPyTorchJobMnistWithCuda (t * testing.T ) {
36+
37+ func TestPyTorchJobMnistMultiNodeWithCuda (t * testing.T ) {
3738 runKFTOPyTorchMnistJob (t , 1 , 1 , "nvidia.com/gpu" , GetCudaTrainingImage (), "resources/requirements.txt" )
3839}
3940
40- func TestPyTorchJobMnistWithROCm (t * testing.T ) {
41+ func TestPyTorchJobMnistMultiNodeWithROCm (t * testing.T ) {
4142 runKFTOPyTorchMnistJob (t , 1 , 1 , "amd.com/gpu" , GetROCmTrainingImage (), "resources/requirements-rocm.txt" )
4243}
4344
45+ func TestPyTorchJobMnistMultiNodeMultiGpuWithCuda (t * testing.T ) {
46+ runKFTOPyTorchMnistJob (t , 2 , 1 , "nvidia.com/gpu" , GetCudaTrainingImage (), "resources/requirements.txt" )
47+ }
48+
49+ func TestPyTorchJobMnistMultiNodeMultiGpuWithROCm (t * testing.T ) {
50+ runKFTOPyTorchMnistJob (t , 2 , 1 , "amd.com/gpu" , GetROCmTrainingImage (), "resources/requirements-rocm.txt" )
51+ }
52+
4453func runKFTOPyTorchMnistJob (t * testing.T , numGpus int , workerReplicas int , gpuLabel string , image string , requirementsFile string ) {
4554 test := With (t )
4655
0 commit comments