@@ -56,75 +56,39 @@ type RhaiFeatureConfig struct {
 	Accelerator Accelerator // CPU, NVIDIA, or AMD
 }
 
-// RunRhaiFeaturesTest runs the e2e test for RHAI features with progression tracking only (CPU)
-func RunRhaiFeaturesProgressionTest(t *testing.T) {
+// RunRhaiFeaturesProgressionTest runs the e2e test for RHAI features with progression tracking
+func RunRhaiFeaturesProgressionTest(t *testing.T, accelerator Accelerator) {
 	runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
 		EnableProgressionTracking: true,
 		EnableJitCheckpoint:       false,
 		CheckpointOutputDir:       "/workspace/checkpoints",
 		CheckpointSaveStrategy:    "epoch",
 		CheckpointSaveTotalLimit:  "3",
-		Accelerator:               CPU,
+		Accelerator:               accelerator,
 	})
 }
 
-// RunRhaiFeaturesProgressionTestGPU runs the e2e test for RHAI features with progression tracking (GPU)
-func RunRhaiFeaturesProgressionTestGPU(t *testing.T) {
-	runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
-		EnableProgressionTracking: true,
-		EnableJitCheckpoint:       false,
-		CheckpointOutputDir:       "/workspace/checkpoints",
-		CheckpointSaveStrategy:    "epoch",
-		CheckpointSaveTotalLimit:  "3",
-		Accelerator:               NVIDIA,
-	})
-}
-
-// RunRhaiFeaturesCheckpointTest runs the e2e test for RHAI features with checkpointing only (CPU)
-func RunRhaiFeaturesCheckpointTest(t *testing.T) {
-	runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
-		EnableProgressionTracking: false,
-		EnableJitCheckpoint:       true,
-		CheckpointOutputDir:       "/workspace/checkpoints",
-		CheckpointSaveStrategy:    "epoch",
-		CheckpointSaveTotalLimit:  "3",
-		Accelerator:               CPU,
-	})
-}
-
-// RunRhaiFeaturesCheckpointTestGPU runs the e2e test for RHAI features with checkpointing (GPU)
-func RunRhaiFeaturesCheckpointTestGPU(t *testing.T) {
+// RunRhaiFeaturesCheckpointTest runs the e2e test for RHAI features with checkpointing
+func RunRhaiFeaturesCheckpointTest(t *testing.T, accelerator Accelerator) {
 	runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
 		EnableProgressionTracking: false,
 		EnableJitCheckpoint:       true,
 		CheckpointOutputDir:       "/workspace/checkpoints",
 		CheckpointSaveStrategy:    "epoch",
 		CheckpointSaveTotalLimit:  "3",
-		Accelerator:               NVIDIA,
+		Accelerator:               accelerator,
 	})
 }
 
-// RunRhaiFeaturesAllTest runs the e2e test for RHAI features with both progression tracking and checkpointing (CPU)
-func RunRhaiFeaturesAllTest(t *testing.T) {
+// RunRhaiFeaturesAllTest runs the e2e test for RHAI features with both progression tracking and checkpointing
+func RunRhaiFeaturesAllTest(t *testing.T, accelerator Accelerator) {
 	runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
 		EnableProgressionTracking: true,
 		EnableJitCheckpoint:       true,
 		CheckpointOutputDir:       "/workspace/checkpoints",
 		CheckpointSaveStrategy:    "epoch",
 		CheckpointSaveTotalLimit:  "3",
-		Accelerator:               CPU,
-	})
-}
-
-// RunRhaiFeaturesAllTestGPU runs the e2e test for RHAI features with both features (GPU)
-func RunRhaiFeaturesAllTestGPU(t *testing.T) {
-	runRhaiFeaturesTestWithConfig(t, RhaiFeatureConfig{
-		EnableProgressionTracking: true,
-		EnableJitCheckpoint:       true,
-		CheckpointOutputDir:       "/workspace/checkpoints",
-		CheckpointSaveStrategy:    "epoch",
-		CheckpointSaveTotalLimit:  "3",
-		Accelerator:               NVIDIA,
+		Accelerator:               accelerator,
 	})
 }
 
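With the CPU and GPU variants collapsed into parameterized helpers, the per-accelerator entry points at the call sites presumably shrink to thin wrappers that only differ in the accelerator argument. A minimal sketch (test names here are illustrative, not taken from this PR):

```go
// Hypothetical call sites living alongside the helpers in the same test package.
func TestRhaiFeaturesProgressionCPU(t *testing.T) {
	RunRhaiFeaturesProgressionTest(t, CPU)
}

func TestRhaiFeaturesProgressionGPU(t *testing.T) {
	RunRhaiFeaturesProgressionTest(t, NVIDIA)
}
```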
@@ -186,6 +150,7 @@ func runRhaiFeaturesTestWithConfig(t *testing.T, config RhaiFeatureConfig) {
 
 	shellCmd := fmt.Sprintf(
 		"set -e; " +
+			"export IPYTHONDIR='/tmp/.ipython'; " +
 			"export OPENSHIFT_API_URL='%s'; " +
 			"export NOTEBOOK_TOKEN='%s'; " +
 			"export NOTEBOOK_NAMESPACE='%s'; " +
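The new `IPYTHONDIR` export likely works around IPython's need for a writable profile directory: OpenShift typically runs pods with an arbitrary UID and no writable `$HOME`, so `~/.ipython` cannot be created. A standalone sketch of the pattern (the helper name and command are illustrative, not part of this PR):

```go
package main

import (
	"fmt"
	"strings"
)

// withWritableIPythonDir prepends the export so any IPython-based tool run by
// cmd finds a profile directory it can actually create.
func withWritableIPythonDir(cmd string) string {
	exports := []string{
		"set -e",
		"export IPYTHONDIR='/tmp/.ipython'", // /tmp is writable regardless of UID
	}
	return strings.Join(append(exports, cmd), "; ")
}

func main() {
	fmt.Println(withWritableIPythonDir("run-notebook.sh")) // command name illustrative
}
```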
@@ -410,7 +375,7 @@ func verifyCheckpoints(test Test, namespace, trainJobName, checkpointDir string)
 			}
 		}
 		return runningCount
-	}, TestTimeoutMedium, 5*time.Second).Should(BeNumerically(">", 0), "At least one training pod should be running")
+	}, TestTimeoutLong, 5*time.Second).Should(BeNumerically(">", 0), "At least one training pod should be running")
 	test.T().Log("Training pods are running")
 
 	// Wait for 1st epoch to complete before suspending
@@ -491,19 +456,39 @@ func verifyCheckpoints(test Test, namespace, trainJobName, checkpointDir string)
 	}, TestTimeoutMedium, 5*time.Second).Should(BeNumerically(">", 0), "Training pods should start after resume")
 	test.T().Log("New training pods started")
 
-	// Wait for training to make progress after resume, then capture progress
-	test.T().Log("Waiting for training to make progress after resume...")
-	test.Eventually(func() int {
-		return getProgressPercentage(test, namespace, trainJobName)
-	}, TestTimeoutMedium, 5*time.Second).Should(BeNumerically(">", 0), "Training should show progress after resume")
+	// Capture pre-suspend state for comparison; the annotation still holds the
+	// values written before suspension, since no resumed pod has reported yet
+	preSuspendEpoch := getCurrentEpoch(test, namespace, trainJobName)
+	preSuspendLastUpdated := getLastUpdatedTime(test, namespace, trainJobName)
+	test.T().Logf("Pre-suspend state: epoch=%.2f, progress=%d%%, lastUpdatedTime=%s", preSuspendEpoch, preSuspendProgress, preSuspendLastUpdated)
+
+	// Wait for the new pods to update the annotation (lastUpdatedTime must change).
+	// This ensures we read fresh data from the resumed pods, not stale pre-suspend data.
+	test.T().Log("Waiting for resumed pods to report fresh progress...")
+	test.Eventually(func() bool {
+		newLastUpdated := getLastUpdatedTime(test, namespace, trainJobName)
+		if newLastUpdated != "" && newLastUpdated != preSuspendLastUpdated {
+			test.T().Logf("Fresh progress detected: lastUpdatedTime=%s", newLastUpdated)
+			return true
+		}
+		return false
+	}, TestTimeoutMedium, 5*time.Second).Should(BeTrue(), "Annotation should be updated by resumed pods (lastUpdatedTime should change)")
+
+	// Now read the fresh values reported by the resumed pods
+	initialResumeEpoch := getCurrentEpoch(test, namespace, trainJobName)
+	initialResumeProgress := getProgressPercentage(test, namespace, trainJobName)
+	test.T().Logf("Initial after resume: epoch=%.2f, progress=%d%%", initialResumeEpoch, initialResumeProgress)
+
+	// Epoch after resume should be >= pre-suspend (checkpoint preserves state)
+	test.Expect(initialResumeEpoch).To(BeNumerically(">=", preSuspendEpoch),
+		fmt.Sprintf("Initial epoch after resume (%.2f) should be >= pre-suspend epoch (%.2f) - checkpoint should preserve training state", initialResumeEpoch, preSuspendEpoch))
 
-	postResumeProgress := getProgressPercentage(test, namespace, trainJobName)
-	test.T().Logf("Progress AFTER resume: %d%%", postResumeProgress)
+	// Progress percentage should also be preserved
+	test.Expect(initialResumeProgress).To(BeNumerically(">=", preSuspendProgress),
+		fmt.Sprintf("Initial progress after resume (%d%%) should be >= pre-suspend progress (%d%%) - checkpoint should preserve progress", initialResumeProgress, preSuspendProgress))
 
-	// Verify progress after resume is >= progress before suspension (checkpoint worked)
-	test.T().Logf("Checkpoint validation: pre-suspend=%d%%, post-resume=%d%%", preSuspendProgress, postResumeProgress)
-	test.Expect(postResumeProgress).To(BeNumerically(">=", preSuspendProgress),
-		fmt.Sprintf("Progress after resume (%d%%) should be >= before suspend (%d%%) - checkpoint should preserve progress", postResumeProgress, preSuspendProgress))
+	test.T().Log("Checkpoint verification: Training resumed from correct epoch and progress")
 
 	// Step 5: Wait for training to complete or fail (to get final state)
 	test.T().Log("Step 5: Waiting for training to complete after resume...")
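The freshness gate above keys on `lastUpdatedTime` changing rather than on progress increasing, because the annotation retains its pre-suspend values until a resumed pod writes again; a progress-based check could pass on stale data. The guard in isolation (a sketch; names hypothetical):

```go
package main

import "fmt"

// isFreshUpdate reports whether a re-read lastUpdatedTime proves the
// annotation was rewritten after the captured pre-suspend value; an empty
// current value means the annotation is missing or not yet parsed.
func isFreshUpdate(captured, current string) bool {
	return current != "" && current != captured
}

func main() {
	fmt.Println(isFreshUpdate("2025-01-15T10:42:07Z", "2025-01-15T10:42:07Z")) // false: stale read
	fmt.Println(isFreshUpdate("2025-01-15T10:42:07Z", "2025-01-15T10:45:31Z")) // true: fresh data
}
```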
@@ -569,6 +554,50 @@ func getProgressPercentage(test Test, namespace, trainJobName string) int {
 	return 0
 }
 
+// getCurrentEpoch extracts currentEpoch from the trainerStatus annotation
+func getCurrentEpoch(test Test, namespace, trainJobName string) float64 {
+	test.T().Helper()
+
+	trainJob := TrainJob(test, namespace, trainJobName)(test)
+	annotations := trainJob.GetAnnotations()
+	statusJSON, ok := annotations[annotationTrainerStatus]
+	if !ok {
+		return 0
+	}
+
+	var status map[string]interface{}
+	if err := json.Unmarshal([]byte(statusJSON), &status); err != nil {
+		return 0
+	}
+
+	if epoch, ok := status["currentEpoch"].(float64); ok {
+		return epoch
+	}
+	return 0
+}
+
+// getLastUpdatedTime extracts lastUpdatedTime from the trainerStatus annotation
+func getLastUpdatedTime(test Test, namespace, trainJobName string) string {
+	test.T().Helper()
+
+	trainJob := TrainJob(test, namespace, trainJobName)(test)
+	annotations := trainJob.GetAnnotations()
+	statusJSON, ok := annotations[annotationTrainerStatus]
+	if !ok {
+		return ""
+	}
+
+	var status map[string]interface{}
+	if err := json.Unmarshal([]byte(statusJSON), &status); err != nil {
+		return ""
+	}
+
+	if lastUpdated, ok := status["lastUpdatedTime"].(string); ok {
+		return lastUpdated
+	}
+	return ""
+}
+
 // suspendTrainJob toggles the suspend state of a TrainJob using patch
 func suspendTrainJob(test Test, namespace, trainJobName string, suspend bool) {
 	test.T().Helper()
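The two new accessors (and `getProgressPercentage` above) each unmarshal the same annotation into a `map[string]interface{}`; a typed decode could consolidate them. A sketch under the assumed payload shape — `currentEpoch` and `lastUpdatedTime` are taken from the accessors, while the `progressPercentage` key is a guess, since that function's body is not shown in this diff:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// trainerStatus mirrors the fields the test reads from the annotation.
type trainerStatus struct {
	ProgressPercentage int     `json:"progressPercentage"` // key name assumed
	CurrentEpoch       float64 `json:"currentEpoch"`
	LastUpdatedTime    string  `json:"lastUpdatedTime"`
}

func main() {
	raw := `{"progressPercentage":37,"currentEpoch":1.42,"lastUpdatedTime":"2025-01-15T10:45:31Z"}`
	var st trainerStatus
	if err := json.Unmarshal([]byte(raw), &st); err != nil {
		fmt.Println("unparsable annotation:", err)
		return
	}
	fmt.Printf("epoch=%.2f progress=%d%% updated=%s\n", st.CurrentEpoch, st.ProgressPercentage, st.LastUpdatedTime)
}
```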