
Commit fbd80ad

Clean up kernel unit tests (#938)
1 parent 22379d5 commit fbd80ad

6 files changed: +364 −399 lines changed


tests/kernels/conftest.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+from typing import List, Tuple
+
+import pytest
+import torch
+
+
+def create_kv_caches(
+    num_blocks: int,
+    block_size: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = head_size**-0.5
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.empty(size=key_cache_shape,
+                                dtype=dtype,
+                                device='cuda')
+        key_cache.uniform_(-scale, scale)
+        key_caches.append(key_cache)
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.empty(size=value_cache_shape,
+                                  dtype=dtype,
+                                  device='cuda')
+        value_cache.uniform_(-scale, scale)
+        value_caches.append(value_cache)
+    return key_caches, value_caches
+
+
+@pytest.fixture()
+def kv_cache_factory():
+    return create_kv_caches
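For context: x above is the number of elements that fit in 16 bytes for the given dtype, so the key cache is laid out as (num_blocks, num_heads, head_size // x, block_size, x), presumably so the attention kernel can read keys in aligned 16-byte chunks. A minimal sketch of how a kernel test might consume the new fixture; the test name and parameter values below are illustrative, not taken from this commit:

import torch

# Hypothetical consumer of the kv_cache_factory fixture; pytest injects the
# fixture defined in conftest.py by matching the argument name. Like the
# factory itself, this requires a CUDA device.
def test_kv_cache_shapes(kv_cache_factory):
    key_caches, value_caches = kv_cache_factory(
        num_blocks=128,
        block_size=16,
        num_layers=2,
        num_heads=8,
        head_size=64,
        dtype=torch.half,
        seed=0,
    )
    # One key cache and one value cache per layer.
    assert len(key_caches) == len(value_caches) == 2
    # For fp16, x = 16 // 2 = 8, so head_size is split into 64 // 8 groups.
    assert key_caches[0].shape == (128, 8, 64 // 8, 16, 8)
    assert value_caches[0].shape == (128, 8, 64, 16)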

tests/kernels/test_activation.py

Lines changed: 31 additions & 28 deletions
@@ -1,72 +1,75 @@
+import pytest
 import torch
 import torch.nn.functional as F
 from transformers.activations import get_activation
+
 from vllm import activation_ops
 
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
+SEEDS = [0]
+
 
 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = x.chunk(chunks=2, dim=1)
     return F.silu(x1) * x2
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_silu_and_mul(
+def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.silu_and_mul(out, x)
     ref_out = ref_silu_and_mul(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
-def test_silu_and_mul() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_silu_and_mul(num_tokens, d, dtype)
-
-
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_gelu_new(
+def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_new(out, x)
     ref_out = get_activation("gelu_new")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
-def test_gelu_new() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_new(num_tokens, d, dtype)
-
-
-@torch.inference_mode()
-def run_gelu_fast(
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
-
-
-def test_gelu_fast() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_fast(num_tokens, d, dtype)
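Note on the refactor: stacked pytest.mark.parametrize decorators take the cross product of their value lists, so each test above now expands to 3 dtypes × 3 token counts × 4 hidden sizes × 1 seed = 36 independent test cases, and the removed hand-written loops and print statements become unnecessary because pytest reports each case by its parameter values. A self-contained sketch of the same pattern; the test below is illustrative, not part of the commit:

import pytest

SIZES = [2, 4]
SCALES = [1.0, 0.5]


# Stacking parametrize decorators yields one test case per combination,
# so this test runs 2 * 2 = 4 times, once per (size, scale) pair.
@pytest.mark.parametrize("size", SIZES)
@pytest.mark.parametrize("scale", SCALES)
def test_scaled_list_length(size: int, scale: float) -> None:
    assert len([scale] * size) == size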
