
[fsdp, model] feat: add sp for qwen3.5 fsdp grpo training#5920

Open
Zhang1Sheng wants to merge 1 commit into verl-project:main from Zhang1Sheng:main

Conversation

@Zhang1Sheng
Contributor

What does this PR do?

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate with experiment(s) and show results such as training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@Zhang1Sheng Zhang1Sheng changed the title [sp, model] feat: add sp for qwen3.5 fsdp grpo training [fsdp, model] feat: add sp for qwen3.5 fsdp grpo training Apr 8, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for Ulysses sequence parallelism to Qwen3.5 models by patching the Gated DeltaNet forward pass and the attention-mask application logic. The changes shard the depthwise convolution weights, add all-to-all communication for the linear attention heads, and slice parameters such as A_log and dt_bias to the local rank. Review feedback identified critical shape-mismatch issues in the patched forward pass, specifically the use of the unpatched mask function and incorrect tensor shapes in the convolution update and convolution function calls.
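For context, the all-to-all exchange mentioned above reshards activations from sequence-sharded to head-sharded layout: before the exchange each rank holds all heads for a slice of the sequence, and afterwards each rank holds the full sequence for a slice of the heads. A minimal single-process sketch with nested lists (the actual patch would use torch.distributed all-to-all; rank count and shapes here are illustrative assumptions):

```python
# Simulate Ulysses-style all-to-all on one process: each of `sp` ranks
# holds a [seq/sp, heads] shard; after the exchange each rank holds
# [seq, heads/sp] -- the full sequence, but only a subset of heads.

def ulysses_all_to_all(shards, sp):
    # shards[r] is rank r's local tensor: seq/sp rows of `heads` entries
    heads = len(shards[0][0])
    hp = heads // sp  # heads per rank after the exchange
    out = []
    for r in range(sp):  # destination rank r gathers head slice [r*hp, (r+1)*hp)
        rows = []
        for src in range(sp):  # rows arrive from every source rank, in sequence order
            for row in shards[src]:
                rows.append(row[r * hp:(r + 1) * hp])
        out.append(rows)
    return out

# 2 ranks, global seq_len 4, 4 heads; entries named "s<pos>h<head>"
rank0 = [[f"s0h{h}" for h in range(4)], [f"s1h{h}" for h in range(4)]]
rank1 = [[f"s2h{h}" for h in range(4)], [f"s3h{h}" for h in range(4)]]
resharded = ulysses_all_to_all([rank0, rank1], sp=2)
# rank 0 now sees all 4 sequence positions but only heads 0-1
assert resharded[0] == [["s0h0", "s0h1"], ["s1h0", "s1h1"],
                       ["s2h0", "s2h1"], ["s3h0", "s3h1"]]
```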

cache_params: Qwen3_5DynamicCache | None = None,
attention_mask: torch.Tensor | None = None,
):
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

critical

The call to apply_mask_to_padding_states uses the original function imported from transformers, which does not handle sequence parallelism slicing of the attention mask. This will lead to a shape mismatch error when ulysses_sp_size > 1 because hidden_states is sharded but attention_mask is not. Use the patched qwen3_5_apply_mask_to_padding_states defined in this file instead.

Suggested change
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
hidden_states = qwen3_5_apply_mask_to_padding_states(hidden_states, attention_mask)
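To illustrate why the patched helper is needed: hidden_states arrives already sharded to seq_len / sp_size, so the global attention mask must be sliced to the same shard before it is applied. A minimal pure-Python sketch (function names and the list-based tensors are illustrative, not the actual verl implementation, which operates on torch tensors):

```python
# Sketch of an SP-aware mask helper: slice the full-length attention
# mask down to this rank's sequence shard before applying it.

def sp_slice_mask(attention_mask, sp_rank, sp_size):
    seq_len = len(attention_mask)
    shard = seq_len // sp_size  # assumes seq_len is divisible by sp_size
    return attention_mask[sp_rank * shard:(sp_rank + 1) * shard]

def apply_mask_to_padding_states_sp(hidden_states, attention_mask, sp_rank, sp_size):
    # hidden_states is already sharded to shard length; the mask is global,
    # so applying it unsliced would raise a shape mismatch.
    local_mask = sp_slice_mask(attention_mask, sp_rank, sp_size)
    return [h if m else 0 for h, m in zip(hidden_states, local_mask)]

mask = [1, 1, 1, 0]      # global mask, seq_len 4, last position is padding
local_hidden = [10, 20]  # rank 1's shard (positions 2-3) with sp_size 2
assert apply_mask_to_padding_states_sp(local_hidden, mask, 1, 2) == [10, 0]
```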

Comment on lines +374 to +380
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)

high

When seq_len == 1, mixed_qkv has shape [B, 1, D]. However, causal_conv1d_update expects a 2D tensor of shape [B, D]. Additionally, the result must be unsqueezed to [B, D, 1] so that the subsequent transpose(1, 2) at line 422 correctly restores the [B, 1, D] shape required for the split operation at line 424.

Suggested change
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
mixed_qkv = self.causal_conv1d_update(
mixed_qkv.squeeze(1),
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
).unsqueeze(-1)
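The shape bookkeeping behind this fix can be traced without torch. The sketch below tracks only tensor shapes (B and D values are illustrative) through the squeeze/update/unsqueeze/transpose chain the suggestion describes:

```python
# Shape walkthrough of the suggested fix:
# [B, 1, D] --squeeze(1)--> [B, D] --conv update--> [B, D]
# --unsqueeze(-1)--> [B, D, 1] --transpose(1, 2)--> [B, 1, D]

def squeeze1(shape):        # drop the size-1 dim at index 1
    assert shape[1] == 1
    return (shape[0],) + shape[2:]

def unsqueeze_last(shape):  # append a trailing size-1 dim
    return shape + (1,)

def transpose12(shape):     # swap dims 1 and 2
    return (shape[0], shape[2], shape[1])

s = (2, 1, 64)        # [B, 1, D]: single-token decode step
s = squeeze1(s)       # causal_conv1d_update sees the expected 2D [B, D]
assert s == (2, 64)
s = unsqueeze_last(s)           # [B, D, 1] after the update
s = transpose12(s)              # transpose(1, 2) restores [B, 1, D]
assert s == (2, 1, 64)          # ready for the subsequent split
```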

Comment on lines +403 to +409
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=conv_weight,
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)

high

causal_conv1d_fn expects the input tensor x to have the channel dimension as the second dimension (shape [B, D, S]). Since the transpose at line 331 was removed, mixed_qkv is currently [B, S, D]. It must be transposed before calling the convolution function. The result will be [B, D, S], which is then correctly handled by the transpose(1, 2) at line 422.

Suggested change
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=conv_weight,
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv.transpose(1, 2),
weight=conv_weight,
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)
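The same kind of shape trace shows why the prefill path needs the extra transpose: with the original transpose removed, mixed_qkv is [B, S, D], but the convolution wants channels second. A shape-only sketch (sizes are illustrative):

```python
# [B, S, D] --transpose(1, 2)--> [B, D, S] for causal_conv1d_fn;
# the conv output stays [B, D, S], and the later transpose(1, 2)
# brings it back to [B, S, D] for the split.

def transpose12(shape):  # swap dims 1 and 2
    return (shape[0], shape[2], shape[1])

s = (2, 8, 64)        # [B, S, D] after the removed transpose
s = transpose12(s)    # conv input layout [B, D, S]
assert s == (2, 64, 8)
s = transpose12(s)    # post-conv transpose restores [B, S, D]
assert s == (2, 8, 64)
```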
