feat(dflash): add checkpoint resume support and clean up comments (#464)
Conversation
Summary of Changes

Hello @xiaomin-D, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request enhances the DFlash training script with checkpoint resume functionality. Users can now specify a checkpoint directory or automatically resume from the latest saved state, which restores both the model weights and the learning rate scheduler. This improves the fault tolerance and flexibility of long-running training jobs. The PR also removes unnecessary comments from the training script, making the codebase cleaner and easier to maintain, and updates the accompanying example script to reflect current best practices and dataset configurations.
Activity
Code Review
This pull request introduces checkpoint resume functionality for DFlash training and cleans up some comments. The implementation of the resume logic has a couple of critical issues. First, the checkpoint discovery logic is flawed and will not correctly identify the latest checkpoint, preventing the resume feature from working as intended. Second, when resuming, the optimizer's state is not restored, which can negatively impact training convergence. I've provided specific comments and code suggestions to address these critical problems. The comment cleanup is a good improvement for code readability.
```python
draft_model_last_checkpoint = get_last_checkpoint(
    args.output_dir, prefix=r"epoch_\d+_step"
)
```
The call to get_last_checkpoint is incorrect. The prefix argument is not treated as a regular expression, so r"epoch_\d+_step" will not match checkpoint directories like epoch_0_step_1000. Additionally, the get_last_checkpoint function's sorting logic is not suitable for checkpoint names with both epoch and step numbers, as it only sorts by a single trailing number. This will cause the auto-resume functionality to fail or pick the wrong checkpoint. A more robust method is needed to parse and sort checkpoint directories by both epoch and step.
```diff
-draft_model_last_checkpoint = get_last_checkpoint(
-    args.output_dir, prefix=r"epoch_\d+_step"
-)
+checkpoint_dirs = [d for d in os.listdir(args.output_dir) if d.startswith("epoch_") and os.path.isdir(os.path.join(args.output_dir, d))]
+if checkpoint_dirs:
+    latest_checkpoint_dir = max(checkpoint_dirs, key=lambda d: [int(s) for s in d.split("_") if s.isdigit()])
+    draft_model_last_checkpoint = os.path.join(args.output_dir, latest_checkpoint_dir)
```
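A quick standalone check (the directory names below are hypothetical samples) of why the suggested sort key works: splitting on `_` and keeping the numeric parts yields an `[epoch, step]` list, and Python compares lists element by element, so a later epoch always outranks a higher step from an earlier epoch.

```python
# Hypothetical checkpoint directory names, as produced by an
# "epoch_{e}_step_{s}" naming scheme.
checkpoint_dirs = ["epoch_0_step_500", "epoch_1_step_100", "epoch_0_step_1000"]

# The key extracts the numeric components: e.g. "epoch_0_step_1000" -> [0, 1000].
latest = max(checkpoint_dirs,
             key=lambda d: [int(s) for s in d.split("_") if s.isdigit()])

# [1, 100] compares greater than [0, 1000], so epoch 1 wins despite its lower step.
print(latest)
```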
```python
start_epoch = 0
global_step = 0
if resume_state is not None:
    optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
```
When resuming from a checkpoint, only the learning rate scheduler's state is being restored. The optimizer's state (e.g., momentum buffers for Adam) is not loaded, which can negatively impact training convergence. The BF16Optimizer class provides a load_state_dict method that correctly restores both the optimizer and scheduler states.
```diff
-    optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
+    optimizer.load_state_dict(resume_state)
```
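To illustrate the point of the review without depending on the project's `BF16Optimizer`, here is a minimal toy sketch (all names hypothetical) of the round-trip a full `state_dict`/`load_state_dict` pair provides: the checkpoint carries the optimizer's per-parameter buffers as well as the scheduler state, so a resumed run continues with the same momentum statistics.

```python
class ToyOptimizer:
    """Toy stand-in for an optimizer that owns its scheduler state."""

    def __init__(self):
        self.momentum = {}                       # per-parameter buffers (e.g. Adam's m/v)
        self.scheduler_state = {"last_lr": 0.0}  # LR scheduler bookkeeping

    def state_dict(self):
        # A full checkpoint includes BOTH optimizer buffers and scheduler state.
        return {"momentum": dict(self.momentum),
                "scheduler_state_dict": dict(self.scheduler_state)}

    def load_state_dict(self, state):
        self.momentum = dict(state["momentum"])
        self.scheduler_state = dict(state["scheduler_state_dict"])

# Simulate a training run, checkpoint, and resume.
opt = ToyOptimizer()
opt.momentum = {"w": 0.9}
opt.scheduler_state["last_lr"] = 1e-4
ckpt = opt.state_dict()

resumed = ToyOptimizer()
resumed.load_state_dict(ckpt)
```

Restoring only `scheduler_state_dict` would leave `resumed.momentum` empty, which is exactly the convergence hazard the review flags.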
```python
        f"Provided ckpt dir {args.ckpt_dir} is not a valid directory."
    )

if args.resume and os.path.isdir(args.output_dir):
```
If both --ckpt-dir and --resume are specified, the auto-detection from --resume will override the explicit path provided via --ckpt-dir. The explicit path should have higher precedence. Using elif here will ensure that auto-detection is only attempted if --ckpt-dir is not provided.
```diff
-    if args.resume and os.path.isdir(args.output_dir):
+    elif args.resume and os.path.isdir(args.output_dir):
```
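The precedence rule the review asks for can be sketched as a single helper (the function name and signature are hypothetical, not from the PR): an explicit `--ckpt-dir` is validated and returned first, and the `--resume` auto-detection only runs in the `elif` branch when no explicit path was given.

```python
import os


def pick_checkpoint(ckpt_dir, resume, output_dir):
    """Resolve which checkpoint to resume from; explicit path wins over auto-detect."""
    if ckpt_dir is not None:
        if not os.path.isdir(ckpt_dir):
            raise ValueError(f"Provided ckpt dir {ckpt_dir} is not a valid directory.")
        return ckpt_dir
    elif resume and os.path.isdir(output_dir):
        # Auto-detect: newest "epoch_*" subdirectory, ordered by [epoch, step].
        dirs = [d for d in os.listdir(output_dir)
                if d.startswith("epoch_") and os.path.isdir(os.path.join(output_dir, d))]
        if dirs:
            latest = max(dirs, key=lambda d: [int(s) for s in d.split("_") if s.isdigit()])
            return os.path.join(output_dir, latest)
    return None  # fresh run: no checkpoint to resume
```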
Motivation
Add `--ckpt-dir` and `--resume` flags to enable checkpoint resume for DFlash training, following the pattern of the eagle3 training script.
Modifications
Related Issues
Accuracy Test
Benchmark & Profiling
Checklist