Skip to content

Commit 5886134

Browse files
fix: cover multiple Cuda/ROCm based test coverage
Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent 91cb578 commit 5886134

File tree

2 files changed

+119
-72
lines changed

2 files changed

+119
-72
lines changed

tests/trainer/kubeflow_sdk_test.go

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,32 +41,50 @@ func TestSftTrainingHubMultiNodeMultiGPU(t *testing.T) {
4141
sdktests.RunSftTrainingHubMultiGpuDistributedTraining(t)
4242
}
4343

44+
// CPU tests
4445
func TestRhaiTrainingProgressionCPU(t *testing.T) {
4546
Tags(t, Tier1)
46-
sdktests.RunRhaiFeaturesProgressionTest(t)
47+
sdktests.RunRhaiFeaturesProgressionTest(t, support.CPU)
4748
}
4849

49-
func TestRhaiTrainingProgressionGPU(t *testing.T) {
50-
Tags(t, KftoCuda)
51-
sdktests.RunRhaiFeaturesProgressionTestGPU(t)
50+
func TestRhaiJitCheckpointingCPU(t *testing.T) {
51+
Tags(t, Tier1)
52+
sdktests.RunRhaiFeaturesCheckpointTest(t, support.CPU)
5253
}
5354

54-
func TestRhaiJitCheckpointingCPU(t *testing.T) {
55+
func TestRhaiFeaturesCPU(t *testing.T) {
5556
Tags(t, Tier1)
56-
sdktests.RunRhaiFeaturesCheckpointTest(t)
57+
sdktests.RunRhaiFeaturesAllTest(t, support.CPU)
5758
}
5859

59-
func TestRhaiJitCheckpointingGPU(t *testing.T) {
60-
Tags(t, KftoCuda)
61-
sdktests.RunRhaiFeaturesCheckpointTestGPU(t)
60+
// CUDA (NVIDIA) GPU tests - 2 nodes, 1 GPU each
61+
func TestRhaiTrainingProgressionCuda(t *testing.T) {
62+
Tags(t, KftoCuda, MultiNodeGpu(2, support.NVIDIA))
63+
sdktests.RunRhaiFeaturesProgressionTest(t, support.NVIDIA)
6264
}
6365

64-
func TestRhaiFeaturesCPU(t *testing.T) {
65-
Tags(t, Tier1)
66-
sdktests.RunRhaiFeaturesAllTest(t)
66+
func TestRhaiJitCheckpointingCuda(t *testing.T) {
67+
Tags(t, KftoCuda, MultiNodeGpu(2, support.NVIDIA))
68+
sdktests.RunRhaiFeaturesCheckpointTest(t, support.NVIDIA)
69+
}
70+
71+
func TestRhaiFeaturesCuda(t *testing.T) {
72+
Tags(t, KftoCuda, MultiNodeGpu(2, support.NVIDIA))
73+
sdktests.RunRhaiFeaturesAllTest(t, support.NVIDIA)
74+
}
75+
76+
// ROCm (AMD) GPU tests - 2 nodes, 1 GPU each
77+
func TestRhaiTrainingProgressionRocm(t *testing.T) {
78+
Tags(t, KftoRocm, MultiNodeGpu(2, support.AMD))
79+
sdktests.RunRhaiFeaturesProgressionTest(t, support.AMD)
80+
}
81+
82+
func TestRhaiJitCheckpointingRocm(t *testing.T) {
83+
Tags(t, KftoRocm, MultiNodeGpu(2, support.AMD))
84+
sdktests.RunRhaiFeaturesCheckpointTest(t, support.AMD)
6785
}
6886

69-
func TestRhaiFeaturesGPU(t *testing.T) {
70-
Tags(t, KftoCuda)
71-
sdktests.RunRhaiFeaturesAllTestGPU(t)
87+
func TestRhaiFeaturesRocm(t *testing.T) {
88+
Tags(t, KftoRocm, MultiNodeGpu(2, support.AMD))
89+
sdktests.RunRhaiFeaturesAllTest(t, support.AMD)
7290
}

tests/trainer/sdk_tests/rhai_features_tests.go

Lines changed: 86 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -56,75 +56,39 @@ type RhaiFeatureConfig struct {
5656
Accelerator Accelerator // CPU, NVIDIA, or AMD
5757
}
5858

59-
// RunRhaiFeaturesTest runs the e2e test for RHAI features with progression tracking only (CPU)
60-
func RunRhaiFeaturesProgressionTest(t *testing.T) {
59+
// RunRhaiFeaturesProgressionTest runs the e2e test for RHAI features with progression tracking
60+
func RunRhaiFeaturesProgressionTest(t *testing.T, accelerator Accelerator) {
6161
runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
6262
EnableProgressionTracking: true,
6363
EnableJitCheckpoint: false,
6464
CheckpointOutputDir: "/workspace/checkpoints",
6565
CheckpointSaveStrategy: "epoch",
6666
CheckpointSaveTotalLimit: "3",
67-
Accelerator: CPU,
67+
Accelerator: accelerator,
6868
})
6969
}
7070

71-
// RunRhaiFeaturesProgressionTestGPU runs the e2e test for RHAI features with progression tracking (GPU)
72-
func RunRhaiFeaturesProgressionTestGPU(t *testing.T) {
73-
runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
74-
EnableProgressionTracking: true,
75-
EnableJitCheckpoint: false,
76-
CheckpointOutputDir: "/workspace/checkpoints",
77-
CheckpointSaveStrategy: "epoch",
78-
CheckpointSaveTotalLimit: "3",
79-
Accelerator: NVIDIA,
80-
})
81-
}
82-
83-
// RunRhaiFeaturesCheckpointTest runs the e2e test for RHAI features with checkpointing only (CPU)
84-
func RunRhaiFeaturesCheckpointTest(t *testing.T) {
85-
runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
86-
EnableProgressionTracking: false,
87-
EnableJitCheckpoint: true,
88-
CheckpointOutputDir: "/workspace/checkpoints",
89-
CheckpointSaveStrategy: "epoch",
90-
CheckpointSaveTotalLimit: "3",
91-
Accelerator: CPU,
92-
})
93-
}
94-
95-
// RunRhaiFeaturesCheckpointTestGPU runs the e2e test for RHAI features with checkpointing (GPU)
96-
func RunRhaiFeaturesCheckpointTestGPU(t *testing.T) {
71+
// RunRhaiFeaturesCheckpointTest runs the e2e test for RHAI features with checkpointing
72+
func RunRhaiFeaturesCheckpointTest(t *testing.T, accelerator Accelerator) {
9773
runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
9874
EnableProgressionTracking: false,
9975
EnableJitCheckpoint: true,
10076
CheckpointOutputDir: "/workspace/checkpoints",
10177
CheckpointSaveStrategy: "epoch",
10278
CheckpointSaveTotalLimit: "3",
103-
Accelerator: NVIDIA,
79+
Accelerator: accelerator,
10480
})
10581
}
10682

107-
// RunRhaiFeaturesAllTest runs the e2e test for RHAI features with both progression tracking and checkpointing (CPU)
108-
func RunRhaiFeaturesAllTest(t *testing.T) {
83+
// RunRhaiFeaturesAllTest runs the e2e test for RHAI features with both progression tracking and checkpointing
84+
func RunRhaiFeaturesAllTest(t *testing.T, accelerator Accelerator) {
10985
runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
11086
EnableProgressionTracking: true,
11187
EnableJitCheckpoint: true,
11288
CheckpointOutputDir: "/workspace/checkpoints",
11389
CheckpointSaveStrategy: "epoch",
11490
CheckpointSaveTotalLimit: "3",
115-
Accelerator: CPU,
116-
})
117-
}
118-
119-
// RunRhaiFeaturesAllTestGPU runs the e2e test for RHAI features with both features (GPU)
120-
func RunRhaiFeaturesAllTestGPU(t *testing.T) {
121-
runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
122-
EnableProgressionTracking: true,
123-
EnableJitCheckpoint: true,
124-
CheckpointOutputDir: "/workspace/checkpoints",
125-
CheckpointSaveStrategy: "epoch",
126-
CheckpointSaveTotalLimit: "3",
127-
Accelerator: NVIDIA,
91+
Accelerator: accelerator,
12892
})
12993
}
13094

@@ -186,6 +150,7 @@ func runRhaiFeaturesTestWithConfig(t *testing.T, config RhaiFeatureConfig) {
186150

187151
shellCmd := fmt.Sprintf(
188152
"set -e; "+
153+
"export IPYTHONDIR='/tmp/.ipython'; "+
189154
"export OPENSHIFT_API_URL='%s'; "+
190155
"export NOTEBOOK_TOKEN='%s'; "+
191156
"export NOTEBOOK_NAMESPACE='%s'; "+
@@ -410,7 +375,7 @@ func verifyCheckpoints(test Test, namespace, trainJobName, checkpointDir string)
410375
}
411376
}
412377
return runningCount
413-
}, TestTimeoutMedium, 5*time.Second).Should(BeNumerically(">", 0), "At least one training pod should be running")
378+
}, TestTimeoutLong, 5*time.Second).Should(BeNumerically(">", 0), "At least one training pod should be running")
414379
test.T().Log("Training pods are running")
415380

416381
// Wait for 1st epoch to complete before suspending
@@ -491,19 +456,39 @@ func verifyCheckpoints(test Test, namespace, trainJobName, checkpointDir string)
491456
}, TestTimeoutMedium, 5*time.Second).Should(BeNumerically(">", 0), "Training pods should start after resume")
492457
test.T().Log("New training pods started")
493458

494-
// Wait for training to make progress after resume, then capture progress
495-
test.T().Log("Waiting for training to make progress after resume...")
496-
test.Eventually(func() int {
497-
return getProgressPercentage(test, namespace, trainJobName)
498-
}, TestTimeoutMedium, 5*time.Second).Should(BeNumerically(">", 0), "Training should show progress after resume")
459+
// Capture pre-suspend state for comparison
460+
preSuspendEpoch := getCurrentEpoch(test, namespace, trainJobName)
461+
preSuspendLastUpdated := getLastUpdatedTime(test, namespace, trainJobName)
462+
test.T().Logf("Pre-suspend state: epoch=%.2f, progress=%d%%, lastUpdatedTime=%s", preSuspendEpoch, preSuspendProgress, preSuspendLastUpdated)
463+
464+
// Wait for new pods to update the annotation (lastUpdatedTime must change)
465+
// This ensures we're reading fresh data from resumed pods, not stale pre-suspend data
466+
test.T().Log("Waiting for resumed pods to report fresh progress...")
467+
test.Eventually(func() bool {
468+
newLastUpdated := getLastUpdatedTime(test, namespace, trainJobName)
469+
if newLastUpdated != "" && newLastUpdated != preSuspendLastUpdated {
470+
test.T().Logf("Fresh progress detected: lastUpdatedTime=%s", newLastUpdated)
471+
return true
472+
}
473+
return false
474+
}, TestTimeoutMedium, 5*time.Second).Should(BeTrue(), "Annotation should be updated by resumed pods (lastUpdatedTime should change)")
475+
476+
// Now read the fresh values reported by resumed pods
477+
initialResumeEpoch := getCurrentEpoch(test, namespace, trainJobName)
478+
initialResumeProgress := getProgressPercentage(test, namespace, trainJobName)
479+
test.T().Logf("Initial after resume: epoch=%.2f, progress=%d%%", initialResumeEpoch, initialResumeProgress)
480+
481+
// Epoch after resume should be >= pre-suspend (checkpoint preserves state)
482+
test.Expect(initialResumeEpoch).To(BeNumerically(">=", preSuspendEpoch),
483+
fmt.Sprintf("Initial epoch after resume (%.2f) should be >= pre-suspend epoch (%.2f) - checkpoint should preserve training state", initialResumeEpoch, preSuspendEpoch))
499484

500-
postResumeProgress := getProgressPercentage(test, namespace, trainJobName)
501-
test.T().Logf("Progress AFTER resume: %d%%", postResumeProgress)
485+
// Progress percentage should also be preserved
486+
test.Expect(initialResumeProgress).To(BeNumerically(">=", preSuspendProgress),
487+
fmt.Sprintf("Initial progress after resume (%d%%) should be >= pre-suspend progress (%d%%) - checkpoint should preserve progress", initialResumeProgress, preSuspendProgress))
502488

503-
// Verify progress after resume is >= progress before suspension (checkpoint worked)
504-
test.T().Logf("Checkpoint validation: pre-suspend=%d%%, post-resume=%d%%", preSuspendProgress, postResumeProgress)
505-
test.Expect(postResumeProgress).To(BeNumerically(">=", preSuspendProgress),
506-
fmt.Sprintf("Progress after resume (%d%%) should be >= before suspend (%d%%) - checkpoint should preserve progress", postResumeProgress, preSuspendProgress))
489+
test.T().Log("Checkpoint verification: Training resumed from correct epoch and progress")
490+
491+
test.T().Log("Waiting for resumed training to complete...")
507492

508493
// Step 5: Wait for training to complete or fail (to get final state)
509494
test.T().Log("Step 5: Waiting for training to complete after resume...")
@@ -569,6 +554,50 @@ func getProgressPercentage(test Test, namespace, trainJobName string) int {
569554
return 0
570555
}
571556

557+
// getCurrentEpoch extracts currentEpoch from trainerStatus annotation
558+
func getCurrentEpoch(test Test, namespace, trainJobName string) float64 {
559+
test.T().Helper()
560+
561+
trainJob := TrainJob(test, namespace, trainJobName)(test)
562+
annotations := trainJob.GetAnnotations()
563+
statusJSON, ok := annotations[annotationTrainerStatus]
564+
if !ok {
565+
return 0
566+
}
567+
568+
var status map[string]interface{}
569+
if err := json.Unmarshal([]byte(statusJSON), &status); err != nil {
570+
return 0
571+
}
572+
573+
if epoch, ok := status["currentEpoch"].(float64); ok {
574+
return epoch
575+
}
576+
return 0
577+
}
578+
579+
// getLastUpdatedTime extracts lastUpdatedTime from trainerStatus annotation
580+
func getLastUpdatedTime(test Test, namespace, trainJobName string) string {
581+
test.T().Helper()
582+
583+
trainJob := TrainJob(test, namespace, trainJobName)(test)
584+
annotations := trainJob.GetAnnotations()
585+
statusJSON, ok := annotations[annotationTrainerStatus]
586+
if !ok {
587+
return ""
588+
}
589+
590+
var status map[string]interface{}
591+
if err := json.Unmarshal([]byte(statusJSON), &status); err != nil {
592+
return ""
593+
}
594+
595+
if lastUpdated, ok := status["lastUpdatedTime"].(string); ok {
596+
return lastUpdated
597+
}
598+
return ""
599+
}
600+
572601
// suspendTrainJob toggles the suspend state of a TrainJob using patch
573602
func suspendTrainJob(test Test, namespace, trainJobName string, suspend bool) {
574603
test.T().Helper()

0 commit comments

Comments
 (0)