
Commit 6f1df4f

Add silu_and_mul kernel (#8)
* base of silu_and_mul
* refine tests
* rm redundant cast
* ut pass
* fix acc issue
* fix format
* fix format2

Signed-off-by: Ma, Liangliang <[email protected]>
1 parent 5d63fd4 commit 6f1df4f

File tree

8 files changed: +165 −0 lines


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
   set(VLLM_EXT_SRC
     "csrc/xpu/cache.cpp"
     "csrc/xpu/layernorm.cpp"
+    "csrc/xpu/activation.cpp"
     "csrc/xpu/pos_encoding_kernels.cpp"
     "csrc/xpu/torch_bindings.cpp"
   )

csrc/xpu/activation.cpp

Lines changed: 69 additions & 0 deletions
#include <sycl/sycl.hpp>

#include <algorithm>
#include "utils.h"
#include "dispatch_utils.h"

namespace vllm {

template <typename T>
inline T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + sycl::exp((float)-x)));
}

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
inline scalar_t compute(const scalar_t& x, const scalar_t& y) {
  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
void act_and_mul_kernel(scalar_t* __restrict__ out,          // [..., d]
                        const scalar_t* __restrict__ input,  // [..., 2, d]
                        const int d, const sycl::nd_item<3>& item_ct1) {
  const int64_t token_idx = item_ct1.get_group(2);
  for (int64_t idx = item_ct1.get_local_id(2); idx < d;
       idx += item_ct1.get_local_range(2)) {
    const scalar_t x = input[token_idx * 2 * d + idx];
    const scalar_t y = input[token_idx * 2 * d + d + idx];
    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
  }
}

template <typename scalar_t>
void call_silu_and_mul_kernel(torch::Tensor& out, torch::Tensor& input) {
  using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t>::Type;
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  // dpct::dim3 grid(num_tokens);
  // dpct::dim3 block(std::min(d, 1024));
  sycl::range<3> grid(1, 1, num_tokens);
  sycl::range<3> block(1, 1, std::min(d, 1024));
  if (num_tokens == 0) {
    return;
  }
  auto out_ptr = out.data_ptr<scalar_t>();
  auto input_ptr = input.data_ptr<scalar_t>();
  at::DeviceGuard device_guard(input.device());
  auto& queue = vllm::xpu::vllmGetQueue();
  queue.submit([&](sycl::handler& cgh) {
    cgh.parallel_for(
        sycl::nd_range<3>(grid * block, block),
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
          act_and_mul_kernel<sycl_t, silu_kernel, true>(
              (sycl_t*)out_ptr, (sycl_t*)input_ptr, d, item_ct1);
        });
  });
}

}  // namespace vllm

void silu_and_mul(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
{
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "call_silu_and_mul_kernel",
      [&] { vllm::call_silu_and_mul_kernel<scalar_t>(out, input); });
}
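For orientation: the kernel launches one work-group per token (a grid of num_tokens groups, each with up to 1024 work-items) and strides work-items across the hidden size d, writing silu(x) * y for the two halves of the last input dimension (act_first=true). A minimal PyTorch sketch of the same math; the function name here is illustrative, not part of the commit:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: [..., 2 * d] -> out: [..., d]; silu on the first half, multiplied by the second half
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]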

csrc/xpu/ops.h

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
+void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
+
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);

csrc/xpu/torch_bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kXPU, &fused_add_rms_norm);
 
+  // activation ops
+  ops.def("silu_and_mul(Tensor! out, Tensor! input) -> ()");
+  ops.impl("silu_and_mul", torch::kXPU, &silu_and_mul);
+
   // pos_embedding
   ops.def(
       "rotary_embedding(Tensor positions, Tensor! query,"

tests/ops/activation_op.py

Lines changed: 33 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn.functional as F

import tests.register_ops as ops
from tests.ops.custom_ops import CustomOp


class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        self.op = ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

tests/register_ops.py

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
 
 
+def silu_and_mul(out: torch.Tensor, input: torch.Tensor) -> None:
+    torch.ops._C.silu_and_mul(out, input)
+
+
 def rotary_embedding(
     positions: torch.Tensor,
     query: torch.Tensor,

tests/test_activation.py

Lines changed: 46 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from tests.ops.activation_op import SiluAndMul
from tests.utils import opcheck, seed_everything

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 13824]  # Arbitrary values for testing
SEEDS = [0]
XPU_DEVICES = [
    f"xpu:{i}" for i in range(1 if torch.xpu.device_count() == 1 else 2)
]


@pytest.mark.parametrize("activation", ["silu_and_mul"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", XPU_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    seed_everything(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
    if activation == "silu_and_mul":
        layer = SiluAndMul()
        fn = torch.ops._C.silu_and_mul
    out = layer(x)
    ref_out = layer.forward_native(x)

    torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-3)

    d = x.shape[-1] // 2
    output_shape = (x.shape[:-1] + (d, ))
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    opcheck(fn, (out, x))

tests/utils.py

Lines changed: 6 additions & 0 deletions
@@ -78,6 +78,12 @@ def opcheck(
 }
 
 
+def seed_everything(seed) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+
 def _convert_from_fp8(
     tensor: torch.Tensor,
     scale: float = 1.0,
