Skip to content

Commit 8440348

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 8b7bea5 + a439296 commit 8440348

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tests/kfto/kfto_mnist_training_test.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,26 @@ import (
3030
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3131
)
3232

33-
func TestPyTorchJobMnistCpu(t *testing.T) {
33+
func TestPyTorchJobMnistMultiNodeCpu(t *testing.T) {
3434
runKFTOPyTorchMnistJob(t, 0, 2, "", GetCudaTrainingImage(), "resources/requirements.txt")
3535
}
36-
func TestPyTorchJobMnistWithCuda(t *testing.T) {
36+
37+
func TestPyTorchJobMnistMultiNodeWithCuda(t *testing.T) {
3738
runKFTOPyTorchMnistJob(t, 1, 1, "nvidia.com/gpu", GetCudaTrainingImage(), "resources/requirements.txt")
3839
}
3940

40-
func TestPyTorchJobMnistWithROCm(t *testing.T) {
41+
func TestPyTorchJobMnistMultiNodeWithROCm(t *testing.T) {
4142
runKFTOPyTorchMnistJob(t, 1, 1, "amd.com/gpu", GetROCmTrainingImage(), "resources/requirements-rocm.txt")
4243
}
4344

45+
func TestPyTorchJobMnistMultiNodeMultiGpuWithCuda(t *testing.T) {
46+
runKFTOPyTorchMnistJob(t, 2, 1, "nvidia.com/gpu", GetCudaTrainingImage(), "resources/requirements.txt")
47+
}
48+
49+
func TestPyTorchJobMnistMultiNodeMultiGpuWithROCm(t *testing.T) {
50+
runKFTOPyTorchMnistJob(t, 2, 1, "amd.com/gpu", GetROCmTrainingImage(), "resources/requirements-rocm.txt")
51+
}
52+
4453
func runKFTOPyTorchMnistJob(t *testing.T, numGpus int, workerReplicas int, gpuLabel string, image string, requirementsFile string) {
4554
test := With(t)
4655

0 commit comments

Comments
 (0)