[Misc] DeepSeek Decode Optimizations #19807

Closed
87 changes: 83 additions & 4 deletions csrc/activation_kernels.cu
@@ -15,22 +15,32 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}
// Activation and gating kernel template.

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void act_and_mul_kernel(
__device__ void _act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
const int d, const int64_t token_idx) {
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
}
}

// Activation and gating kernel template.

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
_act_and_mul_kernel<scalar_t, ACT_FN, act_first>(out, input, d, token_idx);
}

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
@@ -223,3 +233,72 @@ void gelu_quick(torch::Tensor& out, // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}

namespace vllm {
// Batched act_and_mul kernel template
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void batched_act_and_mul_kernel(
scalar_t* out, // [B, max_tokens, d]
const scalar_t* input, // [B, max_tokens, 2, d]
const int32_t* valid_tokens_array, // [B]
const int d, const int max_num_tokens) {
const int64_t batch_idx = blockIdx.x;
const int64_t num_tokens = valid_tokens_array[batch_idx];
if (num_tokens == 0) {
return;
}

int const col_offset = blockIdx.y * blockDim.x;
scalar_t* __restrict__ b_out =
&out[batch_idx * max_num_tokens * d + col_offset];
const scalar_t* __restrict__ b_in =
&input[batch_idx * max_num_tokens * d * 2 + col_offset];

int token_idx = 0;
const int tidx = threadIdx.x;
while (token_idx < num_tokens) {
if (col_offset + tidx < d) {
const scalar_t x = VLLM_LDG(&b_in[tidx]);
const scalar_t y = VLLM_LDG(&b_in[tidx + d]);
b_out[tidx] = compute<scalar_t, ACT_FN, act_first>(x, y);
}

b_out += d;
b_in += (2 * d);

++token_idx;
}
}
} // namespace vllm

// Launch batched activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
#define LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
int64_t const batch_size = input.size(0); \
int64_t const max_num_tokens = input.size(1); \
int const d = input.size(2) / 2; \
int const block_size = std::min(d, 1024); \
int const blocks_per_row = ((d - 1) / block_size) + 1; \
dim3 grid(batch_size, blocks_per_row); \
dim3 block(block_size); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "batched_act_and_mul_kernel", [&] { \
vllm::batched_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, \
ACT_FIRST> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
valid_tokens_array.data_ptr<int32_t>(), d, max_num_tokens); \
});

void batched_silu_and_mul(torch::Tensor& out, // [B, max_tokens, d]
torch::Tensor& input, // [B, max_tokens, 2, d]
torch::Tensor& valid_tokens_array) // [B]
{
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(valid_tokens_array.dtype() == torch::kInt32);
LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
}
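
For orientation, the batched kernel above applies the activation-and-gate elementwise per batch entry: each CUDA block handles one (batch, column-chunk) pair from the `dim3 grid(batch_size, blocks_per_row)` launch and loops over that entry's valid tokens, leaving rows past `valid_tokens_array[b]` untouched. A minimal pure-PyTorch sketch of the intended semantics (illustrative names only, not the vLLM API):

```python
import torch
import torch.nn.functional as F

def batched_silu_and_mul_ref(out: torch.Tensor,           # [B, max_tokens, d]
                             inp: torch.Tensor,           # [B, max_tokens, 2 * d]
                             valid_tokens: torch.Tensor,  # [B], int32
                             ) -> None:
    """Sketch of what the CUDA kernel computes (act_first=True, SiLU)."""
    d = inp.size(-1) // 2
    for b, n in enumerate(valid_tokens.tolist()):
        if n == 0:
            continue  # kernel returns early for empty batch entries
        x = inp[b, :n, :d]  # gate half: activation is applied here
        y = inp[b, :n, d:]  # up-projection half
        out[b, :n] = F.silu(x) * y  # silu(x) == x * sigmoid(x)
        # rows >= n are deliberately left unwritten, matching the kernel
```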
3 changes: 3 additions & 0 deletions csrc/ops.h
@@ -130,6 +130,9 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);

void batched_silu_and_mul(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& valid_tokens_array);

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
5 changes: 5 additions & 0 deletions csrc/torch_bindings.cpp
@@ -111,6 +111,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);

ops.def(
"batched_silu_and_mul(Tensor! result, Tensor input, Tensor "
"valid_tokens_array) -> ()");
ops.impl("batched_silu_and_mul", torch::kCUDA, &batched_silu_and_mul);

ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

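Once the schema above is registered, the op is reachable from Python via `torch.ops._C`, which is how the new test exercises it. A short usage sketch under those assumptions — sizes are illustrative, and the contiguity and int32 requirements mirror the `TORCH_CHECK`s in `batched_silu_and_mul`:

```python
import torch

B, max_tokens, d = 4, 64, 1024  # illustrative sizes
inp = torch.randn(B, max_tokens, 2 * d, device="cuda", dtype=torch.bfloat16)
out = torch.empty(B, max_tokens, d, device="cuda", dtype=torch.bfloat16)
# Per-batch valid token counts; must be int32 (checked by the op).
valid = torch.randint(0, max_tokens + 1, (B,), device="cuda", dtype=torch.int32)

# `result` is declared mutable (Tensor!), so `out` is written in place.
torch.ops._C.batched_silu_and_mul(out, inp, valid)
```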
54 changes: 54 additions & 0 deletions tests/kernels/core/test_activation.py
@@ -14,6 +14,21 @@
SiluAndMul)
from vllm.platforms import current_platform


def ref_batched_silu_mul(x, out, valid_tokens_array):
"""
Reference implementation of batched silu_and_mul
"""
valid_tokens_array = valid_tokens_array.to("cpu")
batch_size = x.size(0)
for b in range(batch_size):
# num valid tokens
n = valid_tokens_array[b]
if n == 0:
continue
torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :])


DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 13824] # Arbitrary values for testing
@@ -106,3 +121,42 @@ def test_activation(

out = torch.empty_like(x)
opcheck(fn, (out, x))


## Test Batched Implementation ####

BATCH_SIZES = [1, 13, 26, 32]
NUM_TOKENS = [7, 37, 64, 4096]
D = [128, 256, 384, 512, 1024, 1536, 13824]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_batched_silu_mul(batch_size: int, num_tokens: int, d: int,
dtype: torch.dtype):

input = torch.randn(
(batch_size, num_tokens, d), device="cuda", dtype=dtype) / 10.0

out = torch.empty((batch_size, num_tokens, d // 2),
device="cuda",
dtype=dtype)

ref_out = out.clone()

# valid num_tokens per batch
valid_num_tokens = torch.randint(low=0,
high=num_tokens + 1,
size=(batch_size, ),
device="cuda").to(dtype=torch.int32)

# reference
ref_batched_silu_mul(input, ref_out, valid_num_tokens)

# impl
torch.ops._C.batched_silu_and_mul(out, input, valid_num_tokens)

torch.testing.assert_close(ref_out, out)
32 changes: 21 additions & 11 deletions tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -15,7 +15,8 @@
from typing_extensions import ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -169,15 +170,16 @@ def make(config: TestConfig, rank) -> "TestTensors":
block_k = block_size[1]
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)

topk_ids = torch.randint(
low=0,
high=config.num_experts,
size=(m, topk),
device=torch.cuda.current_device()).to(dtype=torch.int64)
score = torch.randn((m, config.num_experts),
device="cuda",
dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(rank_tokens, score, topk, False)

topk_weights = torch.randn(topk_ids.shape,
dtype=torch.float32,
device=torch.cuda.current_device())
# overwrite topk_ids to distribute evenly.
topk_ids = torch.empty((m, topk), device="cpu", dtype=torch.int64)
for mi in range(m):
topk_ids[mi] = torch.randperm(config.num_experts)[:topk]
topk_ids = topk_ids.to(device=torch.cuda.current_device())

return TestTensors(rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
@@ -459,6 +461,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
w2, w1_scale, w2_scale)


TOPKS = [2, 6]
MNKs = [
(1, 128, 2560),
(2, 128, 2560),
@@ -467,9 +470,16 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
(45, 512, 2560),
(64, 1024, 2560),
(222, 1024, 2560),
(45, 128, 2560),
(64, 128, 2560),
(222, 128, 2560),
(45, 2048, 2560),
(64, 2048, 2560),
(222, 2048, 2560),
(333, 2048, 2560),
(444, 2048, 2560),
]
# Fix tests for USE_FP8_DISPATCH=True
USE_FP8_DISPATCH = [False]
USE_FP8_DISPATCH = [False, True]


@pytest.mark.parametrize("mnk", MNKs)
Expand Down
Loading