fix: disable attn_tp_input_scattered when input_embeds is provided externally for Kimi-K2.5 #21215
Conversation
Multimodal models like Kimi-K2.5 compute embeddings via `general_mm_embed_routine` outside the `maybe_input_scattered` context, producing full (all_reduced) `input_embeds`. When these are passed to `DeepseekV3ForCausalLM.forward()`, scattered mode incorrectly treats them as TP-partitioned data, causing attention layers to `all_gather` already-complete tensors and produce garbled output. Fix by passing `None` as `forward_batch` to `maybe_input_scattered` when `input_embeds` is externally provided, which disables scattered mode. Also add a `None` guard in `use_input_scattered` for safety.
Code Review
This pull request effectively addresses a bug causing garbled output in multimodal models when using --enable-attn-tp-input-scattered. The root cause, related to how externally computed embeddings are handled in scattered mode, is well understood and the fix is logical. Disabling scattered mode when input_embeds are provided is a sensible solution. The corresponding safeguard in use_input_scattered to handle None is also correctly implemented. The changes are clear and well-commented. I have one minor suggestion to improve the type hinting for better code clarity and safety.
```diff
@@ -220,6 +219,7 @@ def init_context(self, q_lora_rank, is_nsa):
     def use_input_scattered(self, forward_batch: ForwardBatch):
```
Since forward_batch can now be None (as checked on line 222), its type hint should be updated to Optional[ForwardBatch] to maintain type safety and improve code clarity. Please also ensure Optional is imported from the typing module at the top of the file (e.g., from typing import Optional).
```diff
-    def use_input_scattered(self, forward_batch: ForwardBatch):
+    def use_input_scattered(self, forward_batch: Optional[ForwardBatch]):
```
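Taken together, the PR's two changes can be sketched as follows. This is a minimal stand-in sketch, assuming the general shape of the real sglang APIs: `ForwardBatch`, `LayerCommunicator`, and `choose_scatter_batch` here are hypothetical placeholders, not the actual classes and functions from `communicator.py` / `deepseek_v2.py`.

```python
from typing import Optional


class ForwardBatch:
    """Placeholder for sglang's ForwardBatch (hypothetical stand-in)."""


class LayerCommunicator:
    """Hypothetical stand-in for the communicator holding the scattered flag."""

    def __init__(self, attn_tp_input_scattered: bool):
        self.attn_tp_input_scattered = attn_tp_input_scattered

    def use_input_scattered(self, forward_batch: Optional[ForwardBatch]) -> bool:
        # communicator.py change: a None forward_batch now safely disables
        # scattered mode instead of being dereferenced.
        return forward_batch is not None and self.attn_tp_input_scattered


def choose_scatter_batch(forward_batch, input_embeds):
    # deepseek_v2.py change: externally provided input_embeds are already
    # full (all_reduced), so pass None to opt out of scattered mode.
    return None if input_embeds is not None else forward_batch


comm = LayerCommunicator(attn_tp_input_scattered=True)
fb = ForwardBatch()
print(comm.use_input_scattered(choose_scatter_batch(fb, None)))      # True
print(comm.use_input_scattered(choose_scatter_batch(fb, object())))  # False
```

With external embeddings the scattered path is skipped entirely; pure text batches (no `input_embeds`) keep the optimization.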
Motivation
When `--enable-attn-tp-input-scattered` is used with multimodal models based on the DeepSeek V3 architecture (e.g., Kimi-K2.5), the model produces garbled/garbage output. Pure text models like DeepSeek-R1 are not affected.

Root cause: Multimodal models use `general_mm_embed_routine` (in `mm_utils.py`), which computes embeddings outside the `maybe_input_scattered` context manager. The resulting `input_embeds` tensor is already full (all_reduced). When this tensor is passed to `DeepseekV3ForCausalLM.forward()`, the `maybe_input_scattered` context activates `input_scattered=True`, causing attention layers to incorrectly `all_gather` already-complete data, producing 8x repeated/corrupted tensors and garbled output.

In contrast, pure text models (DeepSeek-R1) call `embed_tokens(input_ids)` inside the `maybe_input_scattered` context, where `VocabParallelEmbedding.forward()` correctly skips the `all_reduce`, producing genuinely scattered data that subsequent `all_gather` operations can correctly reconstruct.

Call path comparison:
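The failure mode can be illustrated with a toy, framework-free sketch: plain Python lists stand in for per-rank tensors, `all_gather` is simulated as concatenation of every rank's local data, and `tp_size = 8` mirrors the TP=8 setup in this PR. The numbers are otherwise made up for illustration.

```python
tp_size = 8
full = list(range(16))  # the complete (already all_reduced) embedding
shard_len = len(full) // tp_size

# Correct scattered path: each rank holds only its shard, so gathering
# (concatenating) the shards reconstructs the full tensor.
shards = [full[r * shard_len:(r + 1) * shard_len] for r in range(tp_size)]
gathered = [x for shard in shards for x in shard]
assert gathered == full

# Buggy path: every rank already holds the FULL tensor (externally computed
# input_embeds), yet scattered mode still gathers, yielding tp_size repeated
# copies, i.e. the "8x repeated/corrupted" data described above.
wrongly_gathered = [x for rank_tensor in [full] * tp_size for x in rank_tensor]
print(len(gathered), len(wrongly_gathered))  # 16 128
```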
Modifications
- `python/sglang/srt/models/deepseek_v2.py`: In `DeepseekV3ForCausalLM.forward()`, pass `None` instead of `forward_batch` to `maybe_input_scattered` when `input_embeds` is provided externally. This disables scattered mode for externally computed embeddings that are already full.
- `python/sglang/srt/layers/communicator.py`: Add a `forward_batch is not None` guard in `use_input_scattered()` to safely handle the `None` input from the above change.

Accuracy Tests
Tested on 8x NVIDIA H20 with sglang 0.5.9, Kimi-K2.5 (BF16, TP=8):
`--enable-attn-tp-input-scattered --chunked-prefill-size 32768`

After the fix, sample outputs:
DeepSeek-R1 behavior is unchanged (it was already working correctly, and the fix does not alter its code path, since `input_embeds is None` for pure text models).

Benchmarking and Profiling
This fix only adds a conditional check (`input_embeds is None`) before entering the scattered context. There is no performance impact on pure text models. For multimodal models, scattered mode is disabled when external embeddings are provided; this trades the scattered optimization for correctness. A future optimization could scatter the pre-computed embeddings to re-enable the fast path.

Checklist