|
| 1 | +import pytest |
1 | 2 | import torch
|
2 | 3 | import torch.nn.functional as F
|
3 | 4 | from transformers.activations import get_activation
|
| 5 | + |
4 | 6 | from vllm import activation_ops
|
5 | 7 |
|
# Parameter grids shared by all kernel tests below; each test is run for the
# full cross product via @pytest.mark.parametrize.
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
SEEDS = [0]
6 | 13 |
|
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch reference for the fused silu_and_mul kernel.

    Splits ``x`` into two equal halves along dim 1 and returns
    ``silu(first_half) * second_half``.
    """
    gate, up = torch.chunk(x, chunks=2, dim=1)
    return up * F.silu(gate)
|
10 | 17 |
|
11 | 18 |
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_silu_and_mul(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the CUDA silu_and_mul kernel against the PyTorch reference."""
    # Seed both CPU and CUDA RNGs so every parametrized case is reproducible.
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # The kernel input packs both halves (gate and multiplier) along dim 1,
    # hence the 2 * d width; the output holds a single half.
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
    out = x.new_empty(num_tokens, d)
    activation_ops.silu_and_mul(out, x)
    assert torch.allclose(out, ref_silu_and_mul(x), atol=1e-5, rtol=1e-5)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_gelu_new(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the CUDA gelu_new kernel against HF's "gelu_new" activation."""
    # Seed both CPU and CUDA RNGs so every parametrized case is reproducible.
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    # The kernel writes its result out-of-place into a same-shaped buffer.
    out = torch.empty_like(x)
    activation_ops.gelu_new(out, x)
    assert torch.allclose(out, get_activation("gelu_new")(x),
                          atol=1e-5, rtol=1e-5)
54 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_gelu_fast(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the CUDA gelu_fast kernel against HF's "gelu_fast" activation.

    Runs under @torch.inference_mode() like the sibling activation tests;
    the decorator was dropped for this test only, which looks accidental.
    """
    # Seed both CPU and CUDA RNGs so every parametrized case is reproducible.
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_fast(out, x)
    ref_out = get_activation("gelu_fast")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
0 commit comments