[Misc] DeepSeek Decode Optimizations #19807
Closed

varun-sundar-rabindranath wants to merge 28 commits into vllm-project:main from neuralmagic:varun/deepseek-decode-opt
Commits (28 total; changes shown from 24 commits)
4d518d1  add batched silu mul
ea96ddd  Refactor per_token_group_quant
abcf846  add batched per token quant
50162ac  batched -> masked
9f13fb0  batched_utils -> masked_kernels
b82cbe5  batched -> masked
b2b365d  add masked silu-mul test
57dc316  fixes and add batched per-token-quant tests
dcace53  relax silu mul tolerance
7ba8335  plugin masked kernels
06d28b2  fix D blocking
8f9cb3d  better testing
c98c2e2  make out_q optional
c20487e  fixes
2fb3d5f  add batched cuda silu and mul
97bda02  fixes
00ccbd4  fixes
fc5bc04  add batched impl tests
67e76b5  use cuda silu mul
8de2fd3  deep_ep + use_fp8_dispatch
b2178be  Quantize kernel with the layout that deepgemm wants
a2b4f8e  rm bad assert (tlrmchlsmth)
3e76435  Merge branch 'deepgemm_layout_scales' into varun/deepseek-decode-opt (tlrmchlsmth)
02fcad1  Merge branch 'varun/deepep-fp8-dispatch' into varun/deepseek-decode-opt
041a9e9  fixes - use-fp8-dispatch
2a7e537  fix topk ids
a2eb4f9  update batched silu mul kernel
fffaf97  fix fp8 dispatch tests
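
Several of these commits rename "batched" kernels to "masked" kernels: activations live in a 3-D buffer of shape (batch, max_tokens, hidden) where only the first valid_tokens[b] rows of each batch slot are meaningful, so the kernels skip the padded tail rather than compute on it and discard the result. The sketch below is a rough pure-PyTorch illustration of masked per-token-group FP8 quantization; the helper name and the symmetric amax scaling are assumptions for illustration, not the kernel added in this PR.

# Illustrative sketch only; not the kernel from this PR. Assumes symmetric
# amax scaling per (token, group) and a padded (batch, max_tokens, hidden)
# layout with valid_tokens[b] real rows per batch slot.
import torch


def masked_group_quant_sketch(x: torch.Tensor,
                              valid_tokens: torch.Tensor,
                              group_size: int = 128):
    B, T, H = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    xg = x.float().reshape(B, T, H // group_size, group_size)
    # One scale per (token, group): group amax divided by the fp8 max value.
    scales = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_q = (xg / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    x_q = x_q.reshape(B, T, H)
    scales = scales.squeeze(-1)
    # A real masked kernel skips rows past valid_tokens[b] entirely instead
    # of quantizing padding and throwing it away afterwards.
    return x_q, scales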
@@ -0,0 +1,154 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test for masked utility kernels.
"""

import pytest
import torch

from vllm.model_executor.layers.fused_moe.masked_kernels import (
    invoke_masked_silu_and_mul, masked_per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.platforms import current_platform


def ref_silu_mul(x, out, valid_tokens_array):

    valid_tokens_array = valid_tokens_array.to("cpu")
    batch_size = x.size(0)
    for b in range(batch_size):
        # num valid tokens
        n = valid_tokens_array[b]
        if n == 0:
            continue
        torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :])


def ref_per_token_group_quant(
        x: torch.Tensor, x_q: torch.Tensor, valid_tokens_array: torch.Tensor,
        group_size: int,
        column_major_scales: bool) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.shape == x_q.shape

    # make scales tensor
    B, NUM_TOKENS, HIDDEN_SIZE = x.shape
    x_q_s = torch.empty((B, NUM_TOKENS, HIDDEN_SIZE // group_size),
                        device="cuda",
                        dtype=torch.float32)

    valid_tokens_array = valid_tokens_array.to("cpu")
    batch_size = x.size(0)
    for b in range(batch_size):
        # num valid tokens
        n = valid_tokens_array[b]
        if n == 0:
            continue
        x_slice = x[b, :n, :]
        xq_slice, xqs_slice = per_token_group_quant_fp8(
            x_slice, group_size, column_major_scales=column_major_scales)
        x_q[b, :n, :].copy_(xq_slice)
        x_q_s[b, :n, :].copy_(xqs_slice)

    return x_q, x_q_s


BATCH_SIZES = [1, 13, 26, 32]
NUM_TOKENS = [7, 37, 64, 4096]

## Tests for masked per_token_group_quant_fp8 ####

HIDDEN_SIZES = [128, 256, 384, 512, 1024]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("column_major_scales", [True])
def test_masked_per_token_group_quant_fp8(batch_size: int, num_tokens: int,
                                          hidden_size: int, dtype: torch.dtype,
                                          column_major_scales: bool):

    DEEPGEMM_BLOCK_SIZE = 128

    input = torch.randn(
        (batch_size, num_tokens, hidden_size), device="cuda",
        dtype=dtype) / 10.0

    out_q = torch.randn((batch_size, num_tokens, hidden_size), device="cuda")
    out_q = out_q.to(dtype=current_platform.fp8_dtype())

    ref_out_q = torch.empty_like(out_q)
    ref_out_q.copy_(out_q)

    # valid num_tokens per batch
    valid_num_tokens = torch.randint(low=0,
                                     high=num_tokens + 1,
                                     size=(batch_size, ),
                                     device="cuda").to(torch.int32)

    # Reference
    ref_out_q, ref_out_scales = ref_per_token_group_quant(
        x=input,
        x_q=ref_out_q,
        valid_tokens_array=valid_num_tokens,
        group_size=DEEPGEMM_BLOCK_SIZE,
        column_major_scales=column_major_scales)

    # Impl
    out_q, out_scales = masked_per_token_group_quant_fp8(
        x=input,
        x_q=out_q,
        valid_tokens_array=valid_num_tokens,
        group_size=DEEPGEMM_BLOCK_SIZE,
        column_major_scales=column_major_scales)

    torch.testing.assert_close(ref_out_q, out_q)

    valid_num_tokens_cpu = valid_num_tokens.to(device="cpu")
    for b in range(valid_num_tokens_cpu.size(0)):
        n = valid_num_tokens_cpu[b]
        torch.testing.assert_close(ref_out_scales[b, :n, :],
                                   out_scales[b, :n, :])


## Tests for masked silu_and_mul ####

HIDDEN_SIZES = [124, 1024, 2176, 2816]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
def test_masked_silu_mul(batch_size: int, num_tokens: int, hidden_size: int,
                         dtype: torch.dtype):

    input = torch.randn(
        (batch_size, num_tokens, hidden_size), device="cuda",
        dtype=dtype) / 10.0

    out = torch.empty((batch_size, num_tokens, hidden_size // 2),
                      device="cuda",
                      dtype=dtype)

    ref_out = torch.empty_like(out)
    ref_out.copy_(out)

    # valid num_tokens per batch
    valid_num_tokens = torch.randint(low=0,
                                     high=num_tokens + 1,
                                     size=(batch_size, ),
                                     device="cuda").to(torch.int32)

    # reference
    ref_silu_mul(input, ref_out, valid_num_tokens)

    # impl
    invoke_masked_silu_and_mul(out, input, valid_num_tokens)

    torch.testing.assert_close(ref_out, out)
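
For readers unfamiliar with it, vLLM's silu_and_mul op applies SiLU to the first half of the last dimension and multiplies the result by the second half, which is why the test's output buffer has hidden_size // 2 columns. Below is a minimal pure-PyTorch sketch of what the masked reference above computes; masked_silu_mul_ref is a hypothetical helper name used only for illustration, not part of this PR.

# Illustrative equivalent of the masked reference:
# out[b, :n] = silu(x[b, :n, :d]) * x[b, :n, d:], with n = valid_tokens[b]
# and d = x.size(-1) // 2. Hypothetical helper, not part of this PR.
import torch


def masked_silu_mul_ref(x: torch.Tensor,
                        valid_tokens: torch.Tensor) -> torch.Tensor:
    d = x.size(-1) // 2
    out = torch.empty(x.shape[:-1] + (d, ), device=x.device, dtype=x.dtype)
    for b in range(x.size(0)):
        n = int(valid_tokens[b])
        if n == 0:
            continue
        gate, up = x[b, :n, :d], x[b, :n, d:]
        out[b, :n] = torch.nn.functional.silu(gate) * up
    return out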
Review comment: This semicolon is unnecessary and can be removed.