|
| 1 | +/* |
| 2 | +Copyright 2025. |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +*/ |
| 16 | + |
| 17 | +package trainer |
| 18 | + |
| 19 | +import ( |
| 20 | + "testing" |
| 21 | + |
| 22 | + trainerv1alpha1 "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" |
| 23 | + . "github.com/onsi/gomega" |
| 24 | + |
| 25 | + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" |
| 26 | + |
| 27 | + . "github.com/opendatahub-io/distributed-workloads/tests/common" |
| 28 | + . "github.com/opendatahub-io/distributed-workloads/tests/common/support" |
| 29 | +) |
| 30 | + |
| 31 | +type ClusterTrainingRuntime struct { |
| 32 | + Name string |
| 33 | + ODHImage string |
| 34 | + RHOAIImage string |
| 35 | +} |
| 36 | + |
| 37 | +var expectedRuntimes = []ClusterTrainingRuntime{ |
| 38 | + {Name: "torch-distributed", ODHImage: "training:py312-cuda128-torch280", RHOAIImage: "odh-training-cuda128-torch28-py312-rhel9"}, |
| 39 | + {Name: "torch-distributed-rocm", ODHImage: "training:py312-rocm64-torch280", RHOAIImage: "odh-training-rocm64-torch28-py312-rhel9"}, |
| 40 | + {Name: "torch-distributed-th03-cuda128-torch28-py312", ODHImage: "training:py312-cuda128-torch280", RHOAIImage: "odh-training-cuda128-torch28-py312-rhel9"}, |
| 41 | + {Name: "training-hub", ODHImage: "training:py312-cuda128-torch280", RHOAIImage: "odh-training-cuda128-torch28-py312-rhel9"}, |
| 42 | + {Name: "training-hub03-cuda128-torch28-py312", ODHImage: "training:py312-cuda128-torch280", RHOAIImage: "odh-training-cuda128-torch28-py312-rhel9"}, |
| 43 | +} |
| 44 | + |
| 45 | +// defaultClusterTrainingRuntime is used across integration tests |
| 46 | +var defaultClusterTrainingRuntime = expectedRuntimes[0].Name |
| 47 | + |
| 48 | +func TestDefaultClusterTrainingRuntimes(t *testing.T) { |
| 49 | + Tags(t, Smoke) |
| 50 | + test := With(t) |
| 51 | + |
| 52 | + // Determine registry based on ODH namespace |
| 53 | + namespace := GetOpenDataHubNamespace(test) |
| 54 | + registryName := GetExpectedRegistry(test, namespace) |
| 55 | + |
| 56 | + // Build a map of expected runtimes for quick lookup |
| 57 | + expectedRuntimeMap := make(map[string]ClusterTrainingRuntime) |
| 58 | + for _, runtime := range expectedRuntimes { |
| 59 | + expectedRuntimeMap[runtime.Name] = runtime |
| 60 | + } |
| 61 | + |
| 62 | + // List all ClusterTrainingRuntimes from the cluster |
| 63 | + runtimeList, err := test.Client().Trainer().TrainerV1alpha1().ClusterTrainingRuntimes().List( |
| 64 | + test.Ctx(), |
| 65 | + metav1.ListOptions{}, |
| 66 | + ) |
| 67 | + test.Expect(err).NotTo(HaveOccurred(), "Failed to list ClusterTrainingRuntimes") |
| 68 | + |
| 69 | + // Track unexpected runtimes and found expected runtimes |
| 70 | + var unexpectedRuntimes []string |
| 71 | + foundRuntimes := make(map[string]bool) |
| 72 | + |
| 73 | + // Iterate over runtimes present in the cluster |
| 74 | + for _, runtime := range runtimeList.Items { |
| 75 | + expectedRuntime, found := expectedRuntimeMap[runtime.Name] |
| 76 | + if !found { |
| 77 | + unexpectedRuntimes = append(unexpectedRuntimes, runtime.Name) |
| 78 | + test.T().Logf("WARNING: Unexpected ClusterTrainingRuntime '%s' found", runtime.Name) |
| 79 | + continue |
| 80 | + } |
| 81 | + |
| 82 | + foundRuntimes[runtime.Name] = true |
| 83 | + test.T().Logf("ClusterTrainingRuntime '%s' is present", runtime.Name) |
| 84 | + |
| 85 | + // Find container image from the runtime spec |
| 86 | + var foundImage string |
| 87 | + for _, replicatedJob := range runtime.Spec.Template.Spec.ReplicatedJobs { |
| 88 | + for _, container := range replicatedJob.Template.Spec.Template.Spec.Containers { |
| 89 | + if container.Image != "" { |
| 90 | + foundImage = container.Image |
| 91 | + break |
| 92 | + } |
| 93 | + } |
| 94 | + if foundImage != "" { |
| 95 | + break |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + test.Expect(foundImage).NotTo(BeEmpty(), "No container image found in ClusterTrainingRuntime %s", runtime.Name) |
| 100 | + test.T().Logf("Image referred in ClusterTrainingRuntime is %s", foundImage) |
| 101 | + |
| 102 | + // Verify image based on environment |
| 103 | + var expectedImage string |
| 104 | + switch registryName { |
| 105 | + case "registry.redhat.io": |
| 106 | + expectedImage = registryName + "/rhoai/" + expectedRuntime.RHOAIImage |
| 107 | + case "quay.io": |
| 108 | + expectedImage = registryName + "/modh/" + expectedRuntime.ODHImage |
| 109 | + default: |
| 110 | + test.T().Fatalf("Unexpected registry: %s", registryName) |
| 111 | + } |
| 112 | + |
| 113 | + test.Expect(foundImage).To(ContainSubstring(expectedImage), |
| 114 | + "Image %s should contain %s", foundImage, expectedImage) |
| 115 | + test.T().Logf("ClusterTrainingRuntime '%s' uses expected image: %s", expectedRuntime.Name, expectedImage) |
| 116 | + } |
| 117 | + |
| 118 | + // Verify all expected runtimes are present |
| 119 | + var missingRuntimes []string |
| 120 | + for _, expected := range expectedRuntimes { |
| 121 | + if !foundRuntimes[expected.Name] { |
| 122 | + missingRuntimes = append(missingRuntimes, expected.Name) |
| 123 | + } |
| 124 | + } |
| 125 | + |
| 126 | + // Fail if any unexpected runtimes found |
| 127 | + test.Expect(unexpectedRuntimes).To(BeEmpty(), |
| 128 | + "Unexpected ClusterTrainingRuntimes found: %v. Please update expectedRuntimes list.", unexpectedRuntimes) |
| 129 | + |
| 130 | + // Fail if any expected runtimes missing |
| 131 | + test.Expect(missingRuntimes).To(BeEmpty(), |
| 132 | + "Missing expected ClusterTrainingRuntimes: %v. These runtimes should be present on the cluster.", missingRuntimes) |
| 133 | + |
| 134 | + test.T().Log("All ClusterTrainingRuntimes verified successfully!") |
| 135 | +} |
| 136 | + |
| 137 | +func TestRunTrainJobWithDefaultClusterTrainingRuntimes(t *testing.T) { |
| 138 | + Tags(t, Sanity) |
| 139 | + test := With(t) |
| 140 | + |
| 141 | + for _, runtime := range expectedRuntimes { |
| 142 | + test.T().Logf("Running TrainJob with ClusterTrainingRuntime: %s", runtime.Name) |
| 143 | + |
| 144 | + // Create a namespace |
| 145 | + namespace := test.NewTestNamespace().Name |
| 146 | + |
| 147 | + // Create TrainJob |
| 148 | + trainJob := createTrainJob(test, namespace, runtime.Name) |
| 149 | + |
| 150 | + // Wait for TrainJob completion |
| 151 | + test.Eventually(TrainJob(test, namespace, trainJob.Name), TestTimeoutLong). |
| 152 | + Should(WithTransform(TrainJobConditionComplete, Equal(metav1.ConditionTrue))) |
| 153 | + |
| 154 | + test.T().Logf("TrainJob with ClusterTrainingRuntime '%s' completed successfully", runtime.Name) |
| 155 | + } |
| 156 | + |
| 157 | + test.T().Log("All TrainJobs with expected ClusterTrainingRuntimes completed successfully !!!") |
| 158 | +} |
| 159 | + |
| 160 | +func createTrainJob(test Test, namespace, runtimeName string) *trainerv1alpha1.TrainJob { |
| 161 | + test.T().Helper() |
| 162 | + |
| 163 | + trainJob := &trainerv1alpha1.TrainJob{ |
| 164 | + ObjectMeta: metav1.ObjectMeta{ |
| 165 | + GenerateName: "test-trainjob-", |
| 166 | + Namespace: namespace, |
| 167 | + }, |
| 168 | + Spec: trainerv1alpha1.TrainJobSpec{ |
| 169 | + RuntimeRef: trainerv1alpha1.RuntimeRef{ |
| 170 | + Name: runtimeName, |
| 171 | + }, |
| 172 | + Trainer: &trainerv1alpha1.Trainer{ |
| 173 | + Command: []string{ |
| 174 | + "python", |
| 175 | + "-c", |
| 176 | + "import torch; print(f'PyTorch version: {torch.__version__}'); print('Training completed successfully')", |
| 177 | + }, |
| 178 | + }, |
| 179 | + }, |
| 180 | + } |
| 181 | + |
| 182 | + createTrainJob, err := test.Client().Trainer().TrainerV1alpha1().TrainJobs(namespace).Create( |
| 183 | + test.Ctx(), |
| 184 | + trainJob, |
| 185 | + metav1.CreateOptions{}, |
| 186 | + ) |
| 187 | + test.Expect(err).NotTo(HaveOccurred(), "Failed to create TrainJob") |
| 188 | + test.T().Logf("Created TrainJob %s/%s successfully", createTrainJob.Namespace, createTrainJob.Name) |
| 189 | + |
| 190 | + return createTrainJob |
| 191 | +} |
| 192 | + |
| 193 | +func deleteTrainJob(test Test, namespace, name string) { |
| 194 | + test.T().Helper() |
| 195 | + |
| 196 | + err := test.Client().Trainer().TrainerV1alpha1().TrainJobs(namespace).Delete( |
| 197 | + test.Ctx(), |
| 198 | + name, |
| 199 | + metav1.DeleteOptions{}, |
| 200 | + ) |
| 201 | + if err != nil { |
| 202 | + test.T().Logf("Warning: Failed to delete TrainJob %s/%s: %v", namespace, name, err) |
| 203 | + } else { |
| 204 | + test.T().Logf("Deleted TrainJob %s/%s successfully", namespace, name) |
| 205 | + } |
| 206 | +} |
0 commit comments