fix: disable attn_tp_input_scattered when input_embeds is provided externally for Kimi-K2.5#21215

Closed
qingchanghan wants to merge 1 commit into sgl-project:main from rednote-ai:fix/attn-tp-scattered-multimodal

Conversation

@qingchanghan

Motivation

When --enable-attn-tp-input-scattered is used with multimodal models based on the DeepSeek V3 architecture (e.g., Kimi-K2.5), the model produces garbled output. Pure text models like DeepSeek-R1 are unaffected.

Root cause: Multimodal models use general_mm_embed_routine (in mm_utils.py) which computes embeddings outside the maybe_input_scattered context manager. The resulting input_embeds tensor is already full (all_reduced). When this tensor is passed to DeepseekV3ForCausalLM.forward(), the maybe_input_scattered context activates input_scattered=True, causing attention layers to incorrectly all_gather already-complete data — producing 8x repeated/corrupted tensors and garbled output.

In contrast, pure text models (DeepSeek-R1) call embed_tokens(input_ids) inside the maybe_input_scattered context, where VocabParallelEmbedding.forward() correctly skips all_reduce, producing genuinely scattered data that subsequent all_gather operations can correctly reconstruct.

Call path comparison:

DeepSeek-R1 (WORKS):
  DeepseekV3ForCausalLM.forward(input_ids=tokens, input_embeds=None)
    └─ with maybe_input_scattered():  # input_scattered = True
         └─ embed_tokens(input_ids)   # skips all_reduce → scattered output ✓
         └─ attention layers all_gather scattered data → correct ✓

Kimi-K2.5 (BROKEN):
  general_mm_embed_routine()
    └─ embed_tokens(input_ids)        # input_scattered=False → all_reduce → FULL output
    └─ language_model(input_embeds=FULL_EMBEDS)
         └─ with maybe_input_scattered():  # input_scattered = True
              └─ uses FULL input_embeds directly (no embed_tokens call)
              └─ attention layers all_gather FULL data → corrupted ✗
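The corruption mode in the broken path can be illustrated with a small, self-contained simulation. This is plain Python, not SGLang code: the list-based scatter/all_gather stand in for the real per-rank tensor collectives, and TP_SIZE is chosen arbitrarily.

```python
TP_SIZE = 4
tokens = list(range(8))  # token positions of the full sequence

def scatter(seq, tp_size):
    """Split the sequence into per-rank shards (the scattered layout)."""
    n = len(seq) // tp_size
    return [seq[i * n:(i + 1) * n] for i in range(tp_size)]

def all_gather(shards):
    """Concatenate every rank's shard back into one full sequence."""
    return [x for shard in shards for x in shard]

# Text path: embeddings produced inside the scattered context, so each
# rank holds only its own slice; all_gather reconstructs the sequence.
scattered = scatter(tokens, TP_SIZE)
assert all_gather(scattered) == tokens  # correct reconstruction

# Multimodal path: every rank already holds the FULL sequence, but the
# scattered context still all_gathers it, repeating the data TP_SIZE times.
full_on_every_rank = [tokens for _ in range(TP_SIZE)]
corrupted = all_gather(full_on_every_rank)
assert corrupted == tokens * TP_SIZE  # 4x repeated sequence, not 8 tokens
```

The second assert is exactly the "8x repeated/corrupted tensors" symptom from the root-cause analysis, scaled down to a 4-way split.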

Modifications

python/sglang/srt/models/deepseek_v2.py:

  • In DeepseekV3ForCausalLM.forward(), pass None instead of forward_batch to maybe_input_scattered when input_embeds is provided externally. This disables scattered mode for externally-computed embeddings that are already full.

python/sglang/srt/layers/communicator.py:

  • Add forward_batch is not None guard in use_input_scattered() to safely handle None input from the above change.
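The control flow of the two modifications can be sketched as follows. The class and the context-manager plumbing here are illustrative stand-ins; only the names maybe_input_scattered / use_input_scattered and the forward_batch is not None guard come from the PR itself, and the exact signatures in sglang may differ.

```python
from contextlib import contextmanager

class ScatterContext:
    """Toy stand-in for the communicator's scattered-input state."""

    def __init__(self, enabled):
        self.enabled = enabled          # --enable-attn-tp-input-scattered
        self.input_scattered = False

    def use_input_scattered(self, forward_batch):
        # The PR adds the `forward_batch is not None` guard, so that the
        # model can pass None to opt out of scattered mode.
        return forward_batch is not None and self.enabled

    @contextmanager
    def maybe_input_scattered(self, forward_batch):
        self.input_scattered = self.use_input_scattered(forward_batch)
        try:
            yield
        finally:
            self.input_scattered = False

ctx = ScatterContext(enabled=True)

# Text path: forward_batch flows through, scattered mode activates.
with ctx.maybe_input_scattered(forward_batch=object()):
    assert ctx.input_scattered

# Multimodal path: input_embeds was computed externally, so the model
# passes None instead of forward_batch and scattered mode stays off.
with ctx.maybe_input_scattered(forward_batch=None):
    assert not ctx.input_scattered
```

Passing None rather than adding a separate boolean keeps the change local to DeepseekV3ForCausalLM.forward() and leaves every other caller of use_input_scattered untouched.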

Accuracy Tests

Tested on 8x NVIDIA H20 with sglang 0.5.9, Kimi-K2.5 (BF16, TP=8):

| Config | Before Fix | After Fix |
| --- | --- | --- |
| --enable-attn-tp-input-scattered --chunked-prefill-size 32768 | Garbled output | Correct output |
| Pure TP=8 (no scattered) | Correct output | Correct output (unchanged) |

After fix, sample outputs:

  • Input: "请用中文简单介绍一下你自己" ("Briefly introduce yourself in Chinese") → Output: "你好!我是 Kimi,由月之暗面(Moonshot AI)开发的人工智能助手。" ("Hello! I'm Kimi, an AI assistant developed by Moonshot AI.") ✓
  • Input: "What is 2+3? Answer in one word." → Output: "five" ✓

DeepSeek-R1 behavior is unchanged (was already working correctly, and the fix does not alter its code path since input_embeds is None for pure text models).

Benchmarking and Profiling

This fix only adds a conditional check (input_embeds is None) before entering the scattered context. No performance impact on pure text models. For multimodal models, scattered mode is disabled when external embeddings are provided — this trades the scattered optimization for correctness. A future optimization could scatter the pre-computed embeddings to re-enable the optimization path.
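A minimal sketch of that future path, in plain Python: slice the externally computed full embeddings into this rank's shard before entering the scattered context, instead of disabling the mode. The function name and the contiguous per-rank slicing policy are hypothetical, not sglang API.

```python
import math

def scatter_external_embeds(full_embeds, tp_rank, tp_size):
    """Return only this rank's contiguous shard of the full embedding
    sequence, so downstream all_gather reconstructs it correctly.
    Hypothetical helper, not part of sglang."""
    shard = math.ceil(len(full_embeds) / tp_size)
    return full_embeds[tp_rank * shard:(tp_rank + 1) * shard]

full = [f"emb{i}" for i in range(8)]   # stand-in for a full input_embeds
assert scatter_external_embeds(full, tp_rank=0, tp_size=4) == ["emb0", "emb1"]
assert scatter_external_embeds(full, tp_rank=3, tp_size=4) == ["emb6", "emb7"]
```

A real implementation would also need to handle sequence lengths not divisible by tp_size and any padding conventions the scattered layout expects.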

Commit message: fix: disable attn_tp_input_scattered when input_embeds is provided externally

Multimodal models like Kimi-K2.5 compute embeddings via
general_mm_embed_routine outside the maybe_input_scattered context,
producing full (all_reduced) input_embeds. When these are passed to
DeepseekV3ForCausalLM.forward(), the scattered mode incorrectly
treats them as TP-partitioned data, causing attention layers to
all_gather already-complete tensors and producing garbled output.

Fix by passing None as forward_batch to maybe_input_scattered when
input_embeds is externally provided, which disables scattered mode.
Also add a None guard in use_input_scattered for safety.
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug affecting multimodal models, specifically those based on the DeepSeek V3 architecture like Kimi-K2.5, when the attn_tp_input_scattered optimization is enabled. The issue stemmed from an incorrect interaction between pre-computed input_embeds and the scattered input mechanism, leading to data corruption and garbled model output. The solution involves conditionally disabling the scattered input mode when external embeddings are provided, ensuring data integrity and correct model behavior for multimodal inputs, while preserving existing functionality and performance for pure text models.

Highlights

  • Fixes Garbled Output in Multimodal Models: Resolved an issue where multimodal models (e.g., Kimi-K2.5) produced garbled output when --enable-attn-tp-input-scattered was enabled and input_embeds were provided.
  • Addresses Incorrect all_gather: The root cause was identified as attention layers incorrectly performing all_gather on input_embeds that were already all-reduced, leading to corrupted data.
  • Conditionally Disables Scattered Mode: Modified DeepseekV3ForCausalLM.forward() to pass None to maybe_input_scattered when input_embeds are externally provided, effectively disabling scattered mode for pre-computed embeddings.
  • Adds None Guard for forward_batch: Introduced a forward_batch is not None check in use_input_scattered() to safely handle the None input passed from the model's forward method.
  • Validated Accuracy and Performance: Confirmed that the fix resolves the garbled output for Kimi-K2.5 and maintains correct behavior for pure text models like DeepSeek-R1. The performance impact is minimal, trading scattered optimization for correctness in multimodal models.


@qingchanghan qingchanghan changed the title fix: disable attn_tp_input_scattered when input_embeds is provided ex… fix: disable attn_tp_input_scattered when input_embeds is provided externally for Kimi-K2.5 Mar 23, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses a bug causing garbled output in multimodal models when using --enable-attn-tp-input-scattered. The root cause, related to how externally computed embeddings are handled in scattered mode, is well understood and the fix is logical. Disabling scattered mode when input_embeds are provided is a sensible solution. The corresponding safeguard in use_input_scattered to handle None is also correctly implemented. The changes are clear and well-commented. I have one minor suggestion to improve the type hinting for better code clarity and safety.

@@ -220,6 +219,7 @@ def init_context(self, q_lora_rank, is_nsa):
def use_input_scattered(self, forward_batch: ForwardBatch):


Severity: medium

Since forward_batch can now be None (as checked on line 222), its type hint should be updated to Optional[ForwardBatch] to maintain type safety and improve code clarity. Please also ensure Optional is imported from the typing module at the top of the file (e.g., from typing import Optional).

Suggested change:
- def use_input_scattered(self, forward_batch: ForwardBatch):
+ def use_input_scattered(self, forward_batch: Optional[ForwardBatch]):
