
Commit 9b29497

Add PyTorch-native implementation of custom layers (#1898)
1 parent 5313c2c commit 9b29497

File tree

6 files changed (+150 / -185 lines)

tests/kernels/test_activation.py

Lines changed: 10 additions & 17 deletions
@@ -1,21 +1,14 @@
 import pytest
 import torch
-import torch.nn.functional as F
-from transformers.activations import get_activation

-from vllm._C import ops
+from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
 D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]


-def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    x1, x2 = x.chunk(chunks=2, dim=1)
-    return F.silu(x1) * x2
-
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -30,9 +23,9 @@ def test_silu_and_mul(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.silu_and_mul(out, x)
-    ref_out = ref_silu_and_mul(x)
+    layer = SiluAndMul()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


@@ -50,9 +43,9 @@ def test_gelu_new(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_new(out, x)
-    ref_out = get_activation("gelu_new")(x)
+    layer = NewGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


@@ -69,7 +62,7 @@ def test_gelu_fast(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_fast(out, x)
-    ref_out = get_activation("gelu_fast")(x)
+    layer = FastGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
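The same pattern repeats for all three activations: layer(x) dispatches to the fused CUDA kernel from vllm._C.ops, while layer._forward(x) is the new PyTorch-native path used as the reference. A minimal standalone sketch of that check outside pytest (illustrative only; assumes a CUDA device, this vLLM revision, and arbitrary tensor sizes):

import torch
from vllm.model_executor.layers.activation import SiluAndMul

# Illustrative sketch, not part of the commit. Requires a CUDA device and a
# vLLM build with its compiled extensions (the module imports vllm._C).
x = torch.randn(16, 2 * 512, dtype=torch.half, device="cuda")
layer = SiluAndMul()
out = layer(x)           # fused CUDA kernel path (vllm._C.ops.silu_and_mul)
ref = layer._forward(x)  # pure-PyTorch reference path
print(torch.allclose(out, ref, atol=1e-5, rtol=1e-5))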

tests/kernels/test_layernorm.py

Lines changed: 24 additions & 35 deletions
@@ -1,58 +1,47 @@
 import pytest
 import torch
-import torch.nn as nn

-from vllm._C import ops
+from vllm.model_executor.layers.layernorm import RMSNorm

 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
+ADD_RESIDUAL = [False, True]
 SEEDS = [0]


-class RefRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        weight = torch.empty(hidden_size)
-        weight.normal_(mean=1.0, std=0.1)
-        self.weight = nn.Parameter(weight)
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
+    add_residual: bool,
     dtype: torch.dtype,
     seed: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)

-    scale = float(hidden_size**-0.5)
-    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
-    x.uniform_(-scale, scale)
-    ref = RefRMSNorm(hidden_size).to(dtype).cuda()
-
-    out = torch.empty_like(x)
-    ops.rms_norm(
-        out,
-        x,
-        ref.weight.data,
-        ref.variance_epsilon,
-    )
-    ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
+    layer = RMSNorm(hidden_size).to(dtype).cuda()
+    layer.weight.data.normal_(mean=1.0, std=0.1)
+    scale = 1 / (2 * hidden_size)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x *= scale
+    residual = torch.randn_like(x) * scale if add_residual else None
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_out = layer._forward(x, residual)
+    out = layer(x, residual)
+    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
+    # numerical errors than other operators because they involve reductions.
+    # Therefore, we use a larger tolerance.
+    if add_residual:
+        assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
+        assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
+    else:
+        assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
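RMSNorm._forward, the PyTorch-native reference the test compares against, lives in vllm/model_executor/layers/layernorm.py and is not shown in this commit view. As a rough illustration of what a native RMSNorm with an optional fused residual typically looks like (a sketch of the standard RMSNorm formula, not necessarily the exact vLLM code; the function name is hypothetical):

import torch

# Illustrative sketch only: RMSNorm computed in float32 with an optional
# residual added before normalization; the updated residual is returned too,
# mirroring the (out, residual) tuple the test checks when add_residual=True.
def rms_norm_native(x, weight, eps=1e-6, residual=None):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = weight * x.to(orig_dtype)
    return x if residual is None else (x, residual)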

tests/kernels/test_pos_encoding.py

Lines changed: 24 additions & 133 deletions
@@ -1,105 +1,23 @@
-from typing import Optional, Tuple
+from typing import Optional

 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F

-from vllm._C import ops
+from vllm.model_executor.layers.rotary_embedding import get_rope

 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
-NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
+NUM_HEADS = [7, 17]  # Arbitrary values for testing
+BATCH_SIZES = [1, 5]  # Arbitrary values for testing
+SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]


-def rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)
-
-
-def apply_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    is_neox_style: bool,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
-    q_embed = (q * cos) + (rotate_fn(q) * sin)
-    k_embed = (k * cos) + (rotate_fn(k) * sin)
-    return q_embed, k_embed
-
-
-class RefRotaryEmbedding(nn.Module):
-    """Reference implementation of rotary embedding."""
-
-    def __init__(
-        self,
-        dim: int,
-        is_neox_style: bool,
-        max_position_embeddings: int = 8192,
-        base: int = 10000,
-    ) -> None:
-        super().__init__()
-        self.rotary_dim = dim
-        self.is_neox_style = is_neox_style
-        self.max_position_embeddings = max_position_embeddings
-
-        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        if is_neox_style:
-            emb = torch.cat((freqs, freqs), dim=-1)
-        else:
-            emb = torch.repeat_interleave(freqs, 2, -1)
-        cos = emb.cos().to(dtype=inv_freq.dtype)
-        sin = emb.sin().to(dtype=inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., :self.rotary_dim]
-        query_pass = query[..., self.rotary_dim:]
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-
-        query_rot = query_rot.transpose(0, 1)
-        key_rot = key_rot.transpose(0, 1)
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-
-        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
-                                        self.is_neox_style)
-        query_rot = query_rot.transpose(0, 1).contiguous()
-        key_rot = key_rot.transpose(0, 1).contiguous()
-
-        query = torch.cat((query_rot, query_pass), dim=-1)
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Output query/key shape: [num_tokens, num_tokens, head_size]
-        return query, key
-
-
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@@ -108,7 +26,8 @@ def forward(
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
-    num_tokens: int,
+    batch_size: int,
+    seq_len: int,
     num_heads: int,
     head_size: int,
     rotary_dim: Optional[int],
@@ -122,53 +41,25 @@ def test_rotary_embedding(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)

-    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
-    query = torch.randn(num_tokens,
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+    rope = rope.to(dtype).cuda()
+
+    positions = torch.randint(0,
+                              max_position, (batch_size, seq_len),
+                              device="cuda")
+    query = torch.randn(batch_size,
+                        seq_len,
                         num_heads * head_size,
                         dtype=dtype,
                         device="cuda")
-    key = torch.randn(num_tokens,
-                      num_heads * head_size,
-                      dtype=dtype,
-                      device="cuda")
-
-    # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(
-        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
-    t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq)
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
-    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
-    out_query = query.clone()
-    out_key = key.clone()
-    ops.rotary_embedding(
-        positions,
-        out_query,
-        out_key,
-        head_size,
-        cos_sin_cache,
-        is_neox_style,
-    )
-
-    # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbedding(
-        dim=rotary_dim,
-        is_neox_style=is_neox_style,
-        max_position_embeddings=max_position,
-        base=base,
-    ).to(dtype=dtype, device="cuda")
-    ref_query, ref_key = ref_rotary_embedding(
-        positions,
-        query.view(num_tokens, num_heads, head_size),
-        key.view(num_tokens, num_heads, head_size),
-    )
-    ref_query = ref_query.view(num_tokens, num_heads * head_size)
-    ref_key = ref_key.view(num_tokens, num_heads * head_size)
+    key = torch.randn_like(query)

+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope._forward(positions, query, key)
+    out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
     assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
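get_rope replaces the hand-rolled cos/sin cache construction: it returns a rotary-embedding module whose forward runs the fused CUDA kernel (in place, per the note in the test) and whose _forward is the PyTorch-native reference. A usage sketch matching the batched shapes in the test (illustrative only; the concrete sizes, max_position, and base values here are arbitrary assumptions, and arguments follow the positional call above):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

# Illustrative sketch, not part of the commit. Requires a CUDA device.
head_size, rotary_dim, max_position, base = 128, 128, 8192, 10000
rope = get_rope(head_size, rotary_dim, max_position, base, True)  # True = NeoX style
rope = rope.to(torch.half).cuda()

batch_size, seq_len, num_heads = 2, 16, 8
positions = torch.randint(0, max_position, (batch_size, seq_len), device="cuda")
query = torch.randn(batch_size, seq_len, num_heads * head_size,
                    dtype=torch.half, device="cuda")
key = torch.randn_like(query)

# Run the native reference first: the kernel path rotates query/key in place.
ref_q, ref_k = rope._forward(positions, query, key)  # PyTorch-native reference
out_q, out_k = rope.forward(positions, query, key)   # fused CUDA kernel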

vllm/model_executor/layers/activation.py

Lines changed: 18 additions & 0 deletions
@@ -1,8 +1,10 @@
 """Custom activation functions."""
+import math
 from typing import Optional

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from vllm._C import ops
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -22,6 +24,11 @@ class SiluAndMul(nn.Module):
         return: (batch_size, seq_len, d) or (num_tokens, d)
     """

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.silu(x[..., :d]) * x[..., d:]
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
@@ -32,6 +39,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 class NewGELU(nn.Module):

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        c = math.sqrt(2.0 / math.pi)
+        return 0.5 * x * (1.0 + torch.tanh(c *
+                                           (x + 0.044715 * torch.pow(x, 3.0))))
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
@@ -40,6 +53,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 class FastGELU(nn.Module):

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
+                                           (1.0 + 0.044715 * x * x)))
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
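Both GELU variants added here implement the tanh approximation of GELU: NewGELU._forward spells out sqrt(2/pi) via math, while FastGELU._forward hard-codes 0.7978845608 (approximately sqrt(2/pi)). On PyTorch 1.12+ the same approximation is also available as F.gelu(x, approximate="tanh"), which makes a convenient extra cross-check of the native path on CPU (a sketch, not part of this commit; assumes a working vLLM install, since the module imports vllm._C at load time):

import torch
import torch.nn.functional as F
from vllm.model_executor.layers.activation import NewGELU

# Quick CPU cross-check of the PyTorch-native path against torch's own
# tanh-approximate GELU (requires PyTorch >= 1.12). Not part of the commit.
x = torch.randn(4, 1024, dtype=torch.float)
native = NewGELU()._forward(x)           # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
builtin = F.gelu(x, approximate="tanh")  # PyTorch's built-in tanh approximation
print(torch.allclose(native, builtin, atol=1e-6))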
