Skip to content

Commit 968742d

Browse files
Adding successful training completion validation to mnist test
1 parent 083084b commit 968742d

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

tests/trainer/resources/mnist.ipynb

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,9 +505,41 @@
505505
},
506506
"outputs": [],
507507
"source": [
508-
"# Wait for the running status, then completion.\n",
509-
"client.wait_for_job_status(name=job_name, status={\"Running\"})\n",
510-
"client.wait_for_job_status(name=job_name, status={\"Complete\"})"
508+
"# Wait for the running status, then wait for completion or failure\n",
509+
"client.wait_for_job_status(name=job_name, status={\"Running\"}, timeout=300)\n",
510+
"client.wait_for_job_status(name=job_name, status={\"Complete\", \"Failed\"}, timeout=900)\n",
511+
"\n",
512+
"# Get job details and logs\n",
513+
"job = client.get_job(name=job_name)\n",
514+
"pod_logs = client.get_job_logs(name=job_name, follow=False)\n",
515+
"\n",
516+
"# Flatten all pod logs into a single list of lines\n",
517+
"logs = []\n",
518+
"for log_line in pod_logs:\n",
519+
" logs.extend(str(log_line).splitlines())\n",
520+
"\n",
521+
"log_text = \"\\n\".join(logs)\n",
522+
"\n",
523+
"print(f\"Training job final status: {job.status}\")\n",
524+
"\n",
525+
"# Check 1: Job status must not be \"Failed\" \n",
526+
"if job.status == \"Failed\":\n",
527+
" print(f\"ERROR: Training job '{job_name}' has Failed status\")\n",
528+
" print(\"Last 30 lines of logs:\")\n",
529+
" for line in logs[-30:]:\n",
530+
" print(line)\n",
531+
" raise RuntimeError(f\"Training job '{job_name}' failed\")\n",
532+
"\n",
533+
"# Check 2: Look for the training completion message in logs\n",
534+
"# This is critical because the training script may catch exceptions and exit 0\n",
535+
"if \"Training is finished\" not in log_text:\n",
536+
" print(f\"ERROR: Training completion message not found in logs\")\n",
537+
" print(\"Last 50 lines of logs:\")\n",
538+
" for line in logs[-50:]:\n",
539+
" print(line)\n",
540+
" raise RuntimeError(f\"Training did not complete successfully - missing completion message\")\n",
541+
"\n",
542+
"print(f\"✓ Training job '{job_name}' completed successfully\")"
511543
]
512544
},
513545
{

0 commit comments

Comments
 (0)