@@ -35,31 +35,52 @@ import (
3535
3636func TestPyTorchJobMnistMultiNodeSingleCpu (t * testing.T ) {
3737 Tags (t , Sanity , MultiNode (3 ))
38- runKFTOPyTorchMnistJob (t , CPU , GetCudaTrainingImage (), "resources/requirements.txt" , 2 , 1 )
38+ runKFTOPyTorchMnistJob (t , CPU , GetTrainingCudaPyTorch251Image (), "resources/requirements.txt" , 2 , 1 )
3939}
40+
4041func TestPyTorchJobMnistMultiNodeMultiCpu (t * testing.T ) {
4142 Tags (t , Tier1 , MultiNode (3 ))
42- runKFTOPyTorchMnistJob (t , CPU , GetCudaTrainingImage (), "resources/requirements.txt" , 2 , 2 )
43+ runKFTOPyTorchMnistJob (t , CPU , GetTrainingCudaPyTorch251Image (), "resources/requirements.txt" , 2 , 2 )
44+ }
45+
46+ func TestPyTorchJobMnistMultiNodeSingleGpuWithCudaPyTorch241 (t * testing.T ) {
47+ Tags (t , KftoCuda )
48+ runKFTOPyTorchMnistJob (t , NVIDIA , GetTrainingCudaPyTorch241Image (), "resources/requirements.txt" , 1 , 1 )
49+ }
50+
51+ func TestPyTorchJobMnistMultiNodeSingleGpuWithCudaPyTorch251 (t * testing.T ) {
52+ Tags (t , KftoCuda )
53+ runKFTOPyTorchMnistJob (t , NVIDIA , GetTrainingCudaPyTorch251Image (), "resources/requirements.txt" , 1 , 1 )
4354}
4455
45- func TestPyTorchJobMnistMultiNodeSingleGpuWithCuda (t * testing.T ) {
56+ func TestPyTorchJobMnistMultiNodeMultiGpuWithCudaPyTorch241 (t * testing.T ) {
4657 Tags (t , KftoCuda )
47- runKFTOPyTorchMnistJob (t , NVIDIA , GetCudaTrainingImage (), "resources/requirements.txt" , 1 , 1 )
58+ runKFTOPyTorchMnistJob (t , NVIDIA , GetTrainingCudaPyTorch241Image (), "resources/requirements.txt" , 1 , 2 )
4859}
4960
50- func TestPyTorchJobMnistMultiNodeMultiGpuWithCuda (t * testing.T ) {
61+ func TestPyTorchJobMnistMultiNodeMultiGpuWithCudaPyTorch251 (t * testing.T ) {
5162 Tags (t , KftoCuda )
52- runKFTOPyTorchMnistJob (t , NVIDIA , GetCudaTrainingImage (), "resources/requirements.txt" , 1 , 2 )
63+ runKFTOPyTorchMnistJob (t , NVIDIA , GetTrainingCudaPyTorch251Image (), "resources/requirements.txt" , 1 , 2 )
64+ }
65+
66+ func TestPyTorchJobMnistMultiNodeSingleGpuWithROCmPyTorch241 (t * testing.T ) {
67+ Tags (t , KftoRocm )
68+ runKFTOPyTorchMnistJob (t , AMD , GetTrainingROCmPyTorch241Image (), "resources/requirements-rocm.txt" , 1 , 1 )
69+ }
70+
71+ func TestPyTorchJobMnistMultiNodeSingleGpuWithROCmPyTorch251 (t * testing.T ) {
72+ Tags (t , KftoRocm )
73+ runKFTOPyTorchMnistJob (t , AMD , GetTrainingROCmPyTorch251Image (), "resources/requirements-rocm.txt" , 1 , 1 )
5374}
5475
55- func TestPyTorchJobMnistMultiNodeSingleGpuWithROCm (t * testing.T ) {
76+ func TestPyTorchJobMnistMultiNodeMultiGpuWithROCmPyTorch241 (t * testing.T ) {
5677 Tags (t , KftoRocm )
57- runKFTOPyTorchMnistJob (t , AMD , GetROCmTrainingImage (), "resources/requirements-rocm.txt" , 1 , 1 )
78+ runKFTOPyTorchMnistJob (t , AMD , GetTrainingROCmPyTorch241Image (), "resources/requirements-rocm.txt" , 1 , 2 )
5879}
5980
60- func TestPyTorchJobMnistMultiNodeMultiGpuWithROCm (t * testing.T ) {
81+ func TestPyTorchJobMnistMultiNodeMultiGpuWithROCmPyTorch251 (t * testing.T ) {
6182 Tags (t , KftoRocm )
62- runKFTOPyTorchMnistJob (t , AMD , GetROCmTrainingImage (), "resources/requirements-rocm.txt" , 1 , 2 )
83+ runKFTOPyTorchMnistJob (t , AMD , GetTrainingROCmPyTorch251Image (), "resources/requirements-rocm.txt" , 1 , 2 )
6384}
6485
6586func runKFTOPyTorchMnistJob (t * testing.T , accelerator Accelerator , image string , requirementsFile string , workerReplicas , numProcPerNode int ) {
0 commit comments