[Feat.][FSDP] Support FSDP SFT #380
Conversation
Summary of Changes

Hello @PopSoda2002, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enables Supervised Fine-Tuning (SFT) within the FSDP training framework. It integrates a dedicated SFT loss function into the training pipeline, refines the data preparation process for SFT samples, and ensures compatibility with the existing reward handling system. The changes are validated with a new end-to-end test case, marking a significant step toward supporting SFT alongside other training paradigms.
Code Review
This pull request adds support for Supervised Fine-Tuning (SFT) with FSDP. The changes include a new loss calculation path for SFT in the FSDP actor, modifications to the rollout process to handle SFT data, and a new test case. My review found a critical security issue with a hardcoded API key in the new test file. I've also identified a few areas for improvement, including removing a duplicated line of code, replacing debug print statements with logging, and simplifying some logic for better readability.
```python
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
# SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE=True
# os.environ["SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE"] = SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE
WANDB_API_KEY = "a37f4796e6205800c4212556a38e1319b5f144b7"
```
A hardcoded W&B API key has been found. This is a significant security risk. Secrets like API keys should never be hardcoded in source code. Instead, they should be loaded from a secure source, such as environment variables. Please remove this key and update the code to fetch it from the environment.
Suggested change:

```diff
-WANDB_API_KEY = "a37f4796e6205800c4212556a38e1319b5f144b7"
+WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
```
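The fix above can be wrapped in a small helper. This is a minimal sketch, not code from the PR: the `get_wandb_key` name is hypothetical, and it treats an empty or whitespace-only value the same as an unset variable so callers can cleanly skip W&B logging.

```python
import os
from typing import Optional


def get_wandb_key(env_var: str = "WANDB_API_KEY") -> Optional[str]:
    """Fetch the W&B API key from the environment; never hardcode it.

    Returns None when the variable is unset or blank, so callers can
    decide whether to skip W&B logging or fail fast.
    """
    key = os.environ.get(env_var)
    if key is not None and not key.strip():
        # Treat an empty/whitespace-only value the same as unset.
        return None
    return key
```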
```python
    local_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches],
    loss_masks=loss_masks,
)
if getattr(self.args, "loss_type", "policy_loss") == "sft_loss":
```
```python
wandb_args = (
    "--use-wandb "
    "--wandb-project miles-lora "
    "--wandb-group lora1-chunk16_True "
    f"--wandb-key {WANDB_API_KEY} "
)
```
After removing the hardcoded API key, you should conditionally add the --wandb-key argument only if the key is available from the environment. This prevents passing --wandb-key None to the training script if the environment variable is not set.
Suggested change:

```diff
-wandb_args = (
-    "--use-wandb "
-    "--wandb-project miles-lora "
-    "--wandb-group lora1-chunk16_True "
-    f"--wandb-key {WANDB_API_KEY} "
-)
+wandb_args = (
+    "--use-wandb "
+    "--wandb-project miles-lora "
+    "--wandb-group lora1-chunk16_True "
+)
+if WANDB_API_KEY:
+    wandb_args += f"--wandb-key {WANDB_API_KEY} "
```
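An alternative to incremental string concatenation is to collect the flags in a list and join once, which avoids trailing-space bookkeeping. A sketch under the assumption that the downstream launcher accepts a flat argument string; the `build_wandb_args` helper and its defaults are illustrative, not part of the PR:

```python
def build_wandb_args(api_key=None,
                     project="miles-lora",
                     group="lora1-chunk16_True"):
    """Build the --wandb-* CLI fragment, omitting --wandb-key when no key is set."""
    args = [
        "--use-wandb",
        f"--wandb-project {project}",
        f"--wandb-group {group}",
    ]
    if api_key:  # skip the flag entirely when the env var is unset
        args.append(f"--wandb-key {api_key}")
    return " ".join(args)
```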
```python
    f"sft_rollout::generate_rollout example data: {sample=} (raw){messages=} (raw){token_ids=} (raw){loss_mask=} {response_length=}"
)
SAMPLE_PRINTED = True
print(f"sft_rollout::generate_rollout samples shape: {len(samples)}, {len(samples[0])}")
```
This print statement appears to be for debugging purposes. It's better to use the logging module for such output to have consistent logging and control over verbosity. Consider replacing it with logger.debug(...) or removing it if it's no longer needed.
Suggested change:

```diff
-print(f"sft_rollout::generate_rollout samples shape: {len(samples)}, {len(samples[0])}")
+logger.debug(f"sft_rollout::generate_rollout samples shape: {len(samples)}, {len(samples[0])}")
```
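The `logger.debug(...)` call above assumes a module-level logger exists. A minimal sketch of that setup; the `log_samples_shape` wrapper is hypothetical, and it uses lazy `%`-formatting so the message is never formatted when DEBUG is disabled:

```python
import logging

# Module-level logger; naming it after the module keeps records attributable.
logger = logging.getLogger(__name__)


def log_samples_shape(samples):
    """Emit the rollout shape as a debug-level record instead of print()."""
    logger.debug("sft_rollout::generate_rollout samples shape: %d, %d",
                 len(samples), len(samples[0]))
```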
```python
if isinstance(messages, (list, tuple)) or hasattr(messages, "tolist"):
    if hasattr(messages, "tolist"):
        messages = messages.tolist()
    else:
        messages = list(messages)
```
The logic to ensure messages is a list can be simplified for better readability and to avoid unnecessarily creating a copy when the input is already a list.
Suggested change:

```diff
-if isinstance(messages, (list, tuple)) or hasattr(messages, "tolist"):
-    if hasattr(messages, "tolist"):
-        messages = messages.tolist()
-    else:
-        messages = list(messages)
+if hasattr(messages, "tolist"):
+    # Handle numpy arrays and other objects with a tolist method
+    messages = messages.tolist()
+elif isinstance(messages, tuple):
+    # Convert tuple to list to allow concatenation
+    messages = list(messages)
```
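The simplified branching can also be factored into a standalone function, which makes the no-copy fast path for lists easy to verify. A sketch only; `normalize_messages` is a hypothetical name, not a helper in the PR:

```python
def normalize_messages(messages):
    """Coerce messages to a plain list without copying when it already is one.

    Handles numpy arrays (or anything exposing .tolist()) and tuples;
    a list passes through unchanged.
    """
    if hasattr(messages, "tolist"):
        return messages.tolist()   # e.g. a numpy array of message dicts
    if isinstance(messages, tuple):
        return list(messages)      # tuples can't be concatenated with a list
    return messages
```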
```python
# Append the label/response to the messages
full_messages = messages + [{"role": "assistant", "content": sample.label}]

# print(f"sft_rollout::generate_rollout full_messages: {full_messages}")
```
Closed as implemented by THUDM/slime#1298
No description provided.