|
| 1 | +import pytest |
1 | 2 | import torch |
2 | 3 | import torch.nn.functional as F |
3 | 4 | from transformers.activations import get_activation |
| 5 | + |
4 | 6 | from vllm import activation_ops |
5 | 7 |
|
# Parameter grids shared by every kernel test below.
DTYPES = [torch.half, torch.bfloat16, torch.float]  # all dtypes the kernels support
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
SEEDS = [0]  # single seed keeps the suite fast; extend for fuzzing
| 12 | + |
6 | 13 |
|
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Eager reference for the fused silu_and_mul kernel.

    Splits ``x`` in half along dim 1 and returns ``silu(first_half) *
    second_half``, i.e. a SwiGLU-style gated activation.
    """
    gate, up = torch.chunk(x, 2, dim=1)
    return up * F.silu(gate)
10 | 17 |
|
11 | 18 |
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_silu_and_mul(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the fused silu_and_mul kernel against the eager reference."""
    # Seed both CPU and CUDA RNGs so every parametrization is reproducible.
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Kernel consumes 2*d features per token and emits d gated outputs.
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
    actual = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.silu_and_mul(actual, x)
    expected = ref_silu_and_mul(x)
    assert torch.allclose(actual, expected, atol=1e-5, rtol=1e-5)
23 | 37 |
|
24 | 38 |
|
25 | | -def test_silu_and_mul() -> None: |
26 | | - for dtype in [torch.half, torch.bfloat16, torch.float]: |
27 | | - for num_tokens in [7, 83, 2048]: |
28 | | - for d in [512, 4096, 5120, 13824]: |
29 | | - print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}') |
30 | | - run_silu_and_mul(num_tokens, d, dtype) |
31 | | - |
32 | | - |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_gelu_new(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the gelu_new kernel against HuggingFace's implementation."""
    # Seed both CPU and CUDA RNGs so every parametrization is reproducible.
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    actual = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_new(actual, x)
    # transformers' "gelu_new" is the tanh-approximation reference.
    expected = get_activation("gelu_new")(x)
    assert torch.allclose(actual, expected, atol=1e-5, rtol=1e-5)
44 | 57 |
|
45 | 58 |
|
46 | | -def test_gelu_new() -> None: |
47 | | - for dtype in [torch.half, torch.bfloat16, torch.float]: |
48 | | - for num_tokens in [7, 83, 2048]: |
49 | | - for d in [512, 4096, 5120, 13824]: |
50 | | - print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}') |
51 | | - run_gelu_new(num_tokens, d, dtype) |
52 | | - |
53 | | - |
54 | | -@torch.inference_mode() |
55 | | -def run_gelu_fast( |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
# BUGFIX: restore @torch.inference_mode(), which the other kernel tests carry
# and which the pre-refactor run_gelu_fast had — it was dropped in the rewrite.
@torch.inference_mode()
def test_gelu_fast(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the gelu_fast kernel against HuggingFace's implementation."""
    # Seed both CPU and CUDA RNGs so every parametrization is reproducible.
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_fast(out, x)
    ref_out = get_activation("gelu_fast")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
65 | | - |
66 | | - |
67 | | -def test_gelu_fast() -> None: |
68 | | - for dtype in [torch.half, torch.bfloat16, torch.float]: |
69 | | - for num_tokens in [7, 83, 2048]: |
70 | | - for d in [512, 4096, 5120, 13824]: |
71 | | - print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}') |
72 | | - run_gelu_fast(num_tokens, d, dtype) |
|
0 commit comments