
Conversation

@PopSoda2002
Contributor

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @PopSoda2002, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enables Supervised Fine-Tuning (SFT) capabilities within the FSDP training framework. It integrates a dedicated SFT loss function into the training pipeline, refines the data preparation process for SFT samples, and ensures compatibility with the existing reward handling system. The changes are validated with a new end-to-end test case, marking a significant step towards supporting SFT alongside other training paradigms.

Highlights

  • SFT Loss Integration: Introduced a new SFT loss calculation path in the FSDP actor, which computes negative log likelihood (NLL) when loss_type is set to 'sft_loss'. This allows the training backend to switch between PPO-style losses and SFT loss based on configuration (a minimal sketch follows this list).
  • SFT Sample Processing: Enhanced the SFT rollout mechanism to correctly prepare samples for training. This includes handling grouped samples, appending the assistant's response (label) to the prompt to form complete messages for tokenization, and robustly generating loss masks, even for zero-length responses.
  • Reward Handling for SFT: Modified the reward post-processing logic to explicitly convert None reward values to 0.0, accommodating scenarios where SFT samples might not have explicit reward signals.
  • New FSDP SFT Test: Added a new test case (test_qwen3_4B_fsdp_sft.py) to validate the FSDP SFT functionality using the Qwen3-4B model, ensuring proper configuration and execution of SFT training with the FSDP backend.
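
To make the SFT loss highlight concrete, here is a minimal, self-contained sketch of a masked NLL loss and the loss_type dispatch. compute_sft_loss is a hypothetical helper name introduced for illustration; loss_mask and the "policy_loss"/"sft_loss" strings mirror the code quoted later in this thread, but this is a sketch under those assumptions, not the PR's actual implementation.

import torch
import torch.nn.functional as F


def compute_sft_loss(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Token-level negative log likelihood, averaged over masked (response) tokens.

    logits:    [seq_len, vocab_size] model outputs
    labels:    [seq_len] target token ids
    loss_mask: [seq_len] 1.0 on assistant/response tokens, 0.0 on prompt tokens
    """
    log_probs = F.log_softmax(logits, dim=-1)
    token_nll = -log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    # clamp(min=1.0) keeps the division safe for zero-length responses,
    # matching the robustness called out in the highlights above.
    return (token_nll * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)


if __name__ == "__main__":
    seq_len, vocab = 8, 32
    logits = torch.randn(seq_len, vocab)
    labels = torch.randint(0, vocab, (seq_len,))
    # Prompt tokens are masked out; only the last five (response) tokens count.
    loss_mask = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.float)

    # Dispatch analogous to the highlight: PPO-style loss by default,
    # NLL when the configuration selects "sft_loss".
    loss_type = "sft_loss"
    if loss_type == "sft_loss":
        print(f"SFT (NLL) loss: {compute_sft_loss(logits, labels, loss_mask).item():.4f}")
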
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its suggestions are incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds support for Supervised Fine-Tuning (SFT) with FSDP. The changes include a new loss calculation path for SFT in the FSDP actor, modifications to the rollout process to handle SFT data, and a new test case. My review found a critical security issue with a hardcoded API key in the new test file. I've also identified a few areas for improvement, including removing a duplicated line of code, replacing debug print statements with logging, and simplifying some logic for better readability.

os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
# SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE=True
# os.environ["SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE"] = SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE
WANDB_API_KEY = "a37f4796e6205800c4212556a38e1319b5f144b7"

Severity: critical (security)

A hardcoded W&B API key has been found. This is a significant security risk. Secrets like API keys should never be hardcoded in source code. Instead, they should be loaded from a secure source, such as environment variables. Please remove this key and update the code to fetch it from the environment.

Suggested change
WANDB_API_KEY = "a37f4796e6205800c4212556a38e1319b5f144b7"
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")

local_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches],
loss_masks=loss_masks,
)
if getattr(self.args, "loss_type", "policy_loss") == "sft_loss":

Severity: high

This line is a duplicate of line 601 and should be removed. Calling unpack_sequences twice on the same packed_batch is redundant and inefficient.

Comment on lines +20 to +25
wandb_args = (
"--use-wandb "
"--wandb-project miles-lora "
"--wandb-group lora1-chunk16_True "
f"--wandb-key {WANDB_API_KEY} "
)

Severity: high

After removing the hardcoded API key, you should conditionally add the --wandb-key argument only if the key is available from the environment. This prevents passing --wandb-key None to the training script if the environment variable is not set.

Suggested change
wandb_args = (
"--use-wandb "
"--wandb-project miles-lora "
"--wandb-group lora1-chunk16_True "
f"--wandb-key {WANDB_API_KEY} "
)
wandb_args = (
"--use-wandb "
"--wandb-project miles-lora "
"--wandb-group lora1-chunk16_True "
)
if WANDB_API_KEY:
wandb_args += f"--wandb-key {WANDB_API_KEY} "

f"sft_rollout::generate_rollout example data: {sample=} (raw){messages=} (raw){token_ids=} (raw){loss_mask=} {response_length=}"
)
SAMPLE_PRINTED = True
print(f"sft_rollout::generate_rollout samples shape: {len(samples)}, {len(samples[0])}")

Severity: medium

This print statement appears to be for debugging purposes. It's better to use the logging module for such output to have consistent logging and control over verbosity. Consider replacing it with logger.debug(...) or removing it if it's no longer needed.

Suggested change
print(f"sft_rollout::generate_rollout samples shape: {len(samples)}, {len(samples[0])}")
logger.debug(f"sft_rollout::generate_rollout samples shape: {len(samples)}, {len(samples[0])}")

Comment on lines +50 to +54
if isinstance(messages, (list, tuple)) or hasattr(messages, "tolist"):
if hasattr(messages, "tolist"):
messages = messages.tolist()
else:
messages = list(messages)

Severity: medium

The logic to ensure messages is a list can be simplified for better readability and to avoid unnecessarily creating a copy when the input is already a list.

Suggested change
if isinstance(messages, (list, tuple)) or hasattr(messages, "tolist"):
if hasattr(messages, "tolist"):
messages = messages.tolist()
else:
messages = list(messages)
if hasattr(messages, "tolist"):
# Handle numpy arrays and other objects with a tolist method
messages = messages.tolist()
elif isinstance(messages, tuple):
# Convert tuple to list to allow concatenation
messages = list(messages)

# Append the label/response to the messages
full_messages = messages + [{"role": "assistant", "content": sample.label}]

# print(f"sft_rollout::generate_rollout full_messages: {full_messages}")

Severity: medium

This commented-out print statement should be removed from the codebase to improve clarity and reduce clutter.

@yushengsu-thu self-assigned this Jan 2, 2026
@PopSoda2002
Contributor Author

Closed as implemented by THUDM/slime#1298

@PopSoda2002 closed this Jan 4, 2026