[Misc] DeepSeek Decode Optimizations #19807

Closed
Changes from 24 commits
77 changes: 73 additions & 4 deletions csrc/activation_kernels.cu
@@ -15,22 +15,32 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}
// Activation and gating kernel template.

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void act_and_mul_kernel(
__device__ void _act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
const int d, const int64_t token_idx) {
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
}
}

// Activation and gating kernel template.

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
_act_and_mul_kernel<scalar_t, ACT_FN, act_first>(out, input, d, token_idx);
}

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
@@ -223,3 +233,62 @@ void gelu_quick(torch::Tensor& out, // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}

namespace vllm {
// Batched act_and_mul kernel template
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void batched_act_and_mul_kernel(
scalar_t* out, // [B, max_tokens, d]
const scalar_t* input, // [B, max_tokens, 2, d]
const int32_t* valid_tokens_array, // [B]
const int d) {
;
Contributor review comment (medium): This semicolon is unnecessary and can be removed.

const int64_t batch_idx = blockIdx.x;
const int64_t num_tokens = valid_tokens_array[batch_idx];
if (num_tokens == 0) {
return;
}

const int64_t token_idx = blockIdx.y;
if (token_idx >= num_tokens) {
return;
}

const int64_t max_num_tokens = gridDim.y;
scalar_t* __restrict__ batch_out = &out[batch_idx * max_num_tokens * d];
const scalar_t* __restrict__ batch_input =
&input[batch_idx * max_num_tokens * d * 2];
_act_and_mul_kernel<scalar_t, ACT_FN, act_first>(batch_out, batch_input, d,
token_idx);
}
} // namespace vllm

// Launch batched activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
#define LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
int64_t const batch_size = input.size(0); \
int64_t const max_num_tokens = input.size(1); \
int const d = input.size(2) / 2; \
dim3 grid(batch_size, max_num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "batched_act_and_mul_kernel", [&] { \
vllm::batched_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, \
ACT_FIRST> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
valid_tokens_array.data_ptr<int32_t>(), d); \
});

void batched_silu_and_mul(torch::Tensor& out, // [B, max_tokens, d]
torch::Tensor& input, // [B, max_tokens, 2, d]
torch::Tensor& valid_tokens_array) // [B]
{
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(valid_tokens_array.dtype() == torch::kInt32);
LAUNCH_BATCHED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
}
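
For context, a minimal sketch of the expected Python call pattern for the new op, assuming the `_C` extension has been rebuilt with this kernel; shapes follow the comments above, and the tensor names and sizes here are illustrative:

import torch

B, max_tokens, d = 4, 64, 1024  # d is the gated hidden size
inp = torch.randn(B, max_tokens, 2 * d, device="cuda", dtype=torch.bfloat16)
out = torch.empty(B, max_tokens, d, device="cuda", dtype=torch.bfloat16)
# Per-batch count of valid (non-padding) tokens; int32 as enforced by the TORCH_CHECK above.
valid = torch.randint(0, max_tokens + 1, (B,), device="cuda", dtype=torch.int32)
torch.ops._C.batched_silu_and_mul(out, inp, valid)
# Rows at or beyond valid[b] are skipped by the kernel and left untouched in `out`.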
3 changes: 3 additions & 0 deletions csrc/ops.h
@@ -130,6 +130,9 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);

void batched_silu_and_mul(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& valid_tokens_array);

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
5 changes: 5 additions & 0 deletions csrc/torch_bindings.cpp
@@ -111,6 +111,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);

ops.def(
"batched_silu_and_mul(Tensor! result, Tensor input, Tensor "
"valid_tokens_array) -> ()");
ops.impl("batched_silu_and_mul", torch::kCUDA, &batched_silu_and_mul);

ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

54 changes: 54 additions & 0 deletions tests/kernels/core/test_activation.py
@@ -14,6 +14,21 @@
SiluAndMul)
from vllm.platforms import current_platform


def ref_batched_silu_mul(x, out, valid_tokens_array):
"""
Reference implementation of batched silu_and_mul
"""
valid_tokens_array = valid_tokens_array.to("cpu")
batch_size = x.size(0)
for b in range(batch_size):
# num valid tokens
n = valid_tokens_array[b]
if n == 0:
continue
torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :])


DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 13824] # Arbitrary values for testing
@@ -106,3 +121,42 @@ def test_activation(

out = torch.empty_like(x)
opcheck(fn, (out, x))


## Test Batched Implementation ####

BATCH_SIZES = [1, 13, 26, 32]
NUM_TOKENS = [7, 37, 64, 4096]
D = [128, 256, 384, 512, 1024, 13824]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_batched_silu_mul(batch_size: int, num_tokens: int, d: int,
dtype: torch.dtype):

input = torch.randn(
(batch_size, num_tokens, d), device="cuda", dtype=dtype) / 10.0

out = torch.empty((batch_size, num_tokens, d // 2),
device="cuda",
dtype=dtype)

ref_out = out.clone()

# valid num_tokens per batch
valid_num_tokens = torch.randint(low=0,
high=num_tokens + 1,
size=(batch_size, ),
device="cuda").to(dtype=torch.int32)

# reference
ref_batched_silu_mul(input, ref_out, valid_num_tokens)

# impl
torch.ops._C.batched_silu_and_mul(out, input, valid_num_tokens)

torch.testing.assert_close(ref_out, out)
154 changes: 154 additions & 0 deletions tests/kernels/moe/test_masked_kernels.py
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test for masked utility kernels.
"""

import pytest
import torch

from vllm.model_executor.layers.fused_moe.masked_kernels import (
invoke_masked_silu_and_mul, masked_per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform


def ref_silu_mul(x, out, valid_tokens_array):

valid_tokens_array = valid_tokens_array.to("cpu")
batch_size = x.size(0)
for b in range(batch_size):
# num valid tokens
n = valid_tokens_array[b]
if n == 0:
continue
torch.ops._C.silu_and_mul(out[b, :n, :], x[b, :n, :])


def ref_per_token_group_quant(
x: torch.Tensor, x_q: torch.Tensor, valid_tokens_array: torch.Tensor,
group_size: int,
column_major_scales: bool) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape == x_q.shape

# make scales tensor
B, NUM_TOKENS, HIDDEN_SIZE = x.shape
x_q_s = torch.empty((B, NUM_TOKENS, HIDDEN_SIZE // group_size),
device="cuda",
dtype=torch.float32)

valid_tokens_array = valid_tokens_array.to("cpu")
batch_size = x.size(0)
for b in range(batch_size):
# num valid tokens
n = valid_tokens_array[b]
if n == 0:
continue
x_slice = x[b, :n, :]
xq_slice, xqs_slice = per_token_group_quant_fp8(
x_slice, group_size, column_major_scales=column_major_scales)
x_q[b, :n, :].copy_(xq_slice)
x_q_s[b, :n, :].copy_(xqs_slice)

return x_q, x_q_s


BATCH_SIZES = [1, 13, 26, 32]
NUM_TOKENS = [7, 37, 64, 4096]

## Tests for masked per_token_group_quant_fp8 ####

HIDDEN_SIZES = [128, 256, 384, 512, 1024]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("column_major_scales", [True])
def test_masked_per_token_group_quant_fp8(batch_size: int, num_tokens: int,
hidden_size: int, dtype: torch.dtype,
column_major_scales: bool):

DEEPGEMM_BLOCK_SIZE = 128

input = torch.randn(
(batch_size, num_tokens, hidden_size), device="cuda",
dtype=dtype) / 10.0

out_q = torch.randn((batch_size, num_tokens, hidden_size), device="cuda")
out_q = out_q.to(dtype=current_platform.fp8_dtype())

ref_out_q = torch.empty_like(out_q)
ref_out_q.copy_(out_q)

# valid num_tokens per batch
valid_num_tokens = torch.randint(low=0,
high=num_tokens + 1,
size=(batch_size, ),
device="cuda").to(torch.int32)

# Reference
ref_out_q, ref_out_scales = ref_per_token_group_quant(
x=input,
x_q=ref_out_q,
valid_tokens_array=valid_num_tokens,
group_size=DEEPGEMM_BLOCK_SIZE,
column_major_scales=column_major_scales)

# Impl
out_q, out_scales = masked_per_token_group_quant_fp8(
x=input,
x_q=out_q,
valid_tokens_array=valid_num_tokens,
group_size=DEEPGEMM_BLOCK_SIZE,
column_major_scales=column_major_scales)

torch.testing.assert_close(ref_out_q, out_q)

valid_num_tokens_cpu = valid_num_tokens.to(device="cpu")
for b in range(valid_num_tokens_cpu.size(0)):
n = valid_num_tokens_cpu[b]
torch.testing.assert_close(ref_out_scales[b, :n, :],
out_scales[b, :n, :])


## Tests for masked silu_and_mul ####

HIDDEN_SIZES = [124, 1024, 2176, 2816]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_masked_silu_mul(batch_size: int, num_tokens: int, hidden_size: int,
dtype: torch.dtype):

input = torch.randn(
(batch_size, num_tokens, hidden_size), device="cuda",
dtype=dtype) / 10.0

out = torch.empty((batch_size, num_tokens, hidden_size // 2),
device="cuda",
dtype=dtype)

ref_out = torch.empty_like(out)
ref_out.copy_(out)

# valid num_tokens per batch
valid_num_tokens = torch.randint(low=0,
high=num_tokens + 1,
size=(batch_size, ),
device="cuda").to(torch.int32)

# reference
ref_silu_mul(input, ref_out, valid_num_tokens)

# impl
invoke_masked_silu_and_mul(out, input, valid_num_tokens)

torch.testing.assert_close(ref_out, out)
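
For context, a minimal sketch of how the two masked utilities exercised above might be composed in a batched-expert decode path, assuming only the signatures used in these tests; the tensor names, sizes, and the bfloat16 dtype are illustrative:

import torch

from vllm.model_executor.layers.fused_moe.masked_kernels import (
    invoke_masked_silu_and_mul, masked_per_token_group_quant_fp8)
from vllm.platforms import current_platform

B, max_tokens, hidden = 8, 64, 2048  # padded per-expert workspace
group_size = 128  # DeepGEMM block size used above

gate_up = torch.randn(B, max_tokens, hidden, device="cuda", dtype=torch.bfloat16)
valid = torch.randint(0, max_tokens + 1, (B,), device="cuda", dtype=torch.int32)

# Masked SiLU-and-mul: only the first valid[b] rows of each batch entry are written.
act = torch.empty(B, max_tokens, hidden // 2, device="cuda", dtype=torch.bfloat16)
invoke_masked_silu_and_mul(act, gate_up, valid)

# Masked per-token-group FP8 quantization of the activation output.
act_q = torch.empty_like(act, dtype=current_platform.fp8_dtype())
act_q, act_scales = masked_per_token_group_quant_fp8(
    x=act,
    x_q=act_q,
    valid_tokens_array=valid,
    group_size=group_size,
    column_major_scales=True)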