Skip to content

Commit 254a792

Browse files
fix: small bug fixes
Signed-off-by: abhijeet-dhumal <[email protected]>
1 parent b89a6b8 commit 254a792

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tests/trainer/sdk_tests/rhai_features_tests.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,7 @@ func runRhaiFeaturesTestWithConfig(t *testing.T, config RhaiFeatureConfig) {
259259
// Verify progress metrics exist and have valid values
260260
test.Expect(trainerStatus).To(HaveKey("progressPercentage"))
261261
progress := trainerStatus["progressPercentage"].(float64)
262-
test.Expect(progress).To(BeNumerically(">=", 0), "Progress should be non-negative")
263-
test.Expect(progress).To(BeNumerically("<=", 100), "Progress should not exceed 100%")
262+
test.Expect(progress).To(BeNumerically("==", 100), "Progress should be 100% at completion")
264263
test.T().Logf("progressPercentage: %.0f%%", progress)
265264

266265
test.Expect(trainerStatus).To(HaveKey("currentStep"))
@@ -273,7 +272,7 @@ func runRhaiFeaturesTestWithConfig(t *testing.T, config RhaiFeatureConfig) {
273272

274273
test.Expect(trainerStatus).To(HaveKey("estimatedRemainingSeconds"))
275274
remaining := trainerStatus["estimatedRemainingSeconds"].(float64)
276-
test.Expect(remaining).To(BeNumerically(">=", 0), "Remaining time should be non-negative")
275+
test.Expect(remaining).To(BeNumerically("==", 0), "Remaining time should be 0 at completion")
277276
test.T().Logf("estimatedRemainingSeconds: %.0f", remaining)
278277

279278
test.T().Log("Progression tracking verification passed!")
@@ -450,10 +449,16 @@ func verifyCheckpoints(test Test, namespace, trainJobName, checkpointDir string,
450449
var preSuspendProgress int
451450
var preSuspendEpoch float64
452451
if progressionEnabled {
452+
// Wait for the operator to poll metrics and update the TrainJob annotations.
453+
// This avoids a race condition where the job is suspended before any progress has been tracked.
454+
test.T().Log("Step 5: Waiting for progress to be tracked in TrainJob...")
455+
test.Eventually(func() int {
456+
return getProgressPercentage(test, namespace, trainJobName)
457+
}, TestTimeoutMedium, 5*time.Second).Should(BeNumerically(">", 0), "Progress should be tracked before suspension")
458+
453459
preSuspendProgress = getProgressPercentage(test, namespace, trainJobName)
454460
preSuspendEpoch = getCurrentEpoch(test, namespace, trainJobName)
455461
test.T().Logf("Pre-suspend state: epoch=%.2f, progress=%d%%", preSuspendEpoch, preSuspendProgress)
456-
test.Expect(preSuspendProgress).To(BeNumerically(">", 0), "Progress should be > 0 at suspension")
457462
}
458463

459464
// Step 6: Resume the TrainJob

0 commit comments

Comments
 (0)