|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import torch |
| 4 | +from torch import nn |
| 5 | +from torch.library import Library |
| 6 | + |
| 7 | +from vllm.compilation.counter import compilation_counter |
| 8 | +from vllm.compilation.decorators import (ignore_torch_compile, |
| 9 | + support_torch_compile) |
| 10 | +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, |
| 11 | + VllmConfig, set_current_vllm_config) |
| 12 | +from vllm.forward_context import set_forward_context |
| 13 | +from vllm.utils import direct_register_custom_op |
| 14 | + |
| 15 | +# create a library to hold the custom op |
| 16 | +silly_lib = Library("silly", "FRAGMENT") # noqa |
| 17 | + |
| 18 | +BATCH_SIZE = 32 |
| 19 | +MLP_SIZE = 128 |
| 20 | + |
| 21 | + |
| 22 | +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, |
| 23 | + out: torch.Tensor) -> None: |
| 24 | + out.copy_(q) |
| 25 | + out += k |
| 26 | + out += v |
| 27 | + |
| 28 | + |
| 29 | +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, |
| 30 | + out: torch.Tensor) -> None: |
| 31 | + return |
| 32 | + |
| 33 | + |
| 34 | +direct_register_custom_op( |
| 35 | + op_name="attention", |
| 36 | + op_func=silly_attention, |
| 37 | + mutates_args=["out"], |
| 38 | + fake_impl=silly_attention_fake, |
| 39 | + target_lib=silly_lib, |
| 40 | +) |
| 41 | + |
| 42 | + |
| 43 | +def test_ignore_torch_compile_decorator(): |
| 44 | + # piecewise |
| 45 | + vllm_config = VllmConfig(compilation_config=CompilationConfig( |
| 46 | + level=CompilationLevel.PIECEWISE, |
| 47 | + use_cudagraph=True, |
| 48 | + splitting_ops=["silly.attention"], |
| 49 | + cudagraph_capture_sizes=[1, 2], |
| 50 | + )) |
| 51 | + |
| 52 | + @support_torch_compile |
| 53 | + class A(nn.Module): |
| 54 | + |
| 55 | + def __init__(self, |
| 56 | + *, |
| 57 | + vllm_config: VllmConfig, |
| 58 | + prefix: str = '', |
| 59 | + **kwargs) -> None: |
| 60 | + super().__init__() |
| 61 | + |
| 62 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 63 | + x = x + x |
| 64 | + attn_output = torch.empty_like(x) |
| 65 | + torch.ops.silly.attention(x, x, x, attn_output) |
| 66 | + x = attn_output |
| 67 | + x = x * 3 |
| 68 | + return x |
| 69 | + |
| 70 | + @ignore_torch_compile |
| 71 | + class B(A): |
| 72 | + ... |
| 73 | + |
| 74 | + @support_torch_compile |
| 75 | + class C(B): |
| 76 | + ... |
| 77 | + |
| 78 | + with set_current_vllm_config(vllm_config): |
| 79 | + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() |
| 80 | + |
| 81 | + # A has support_torch_compile |
| 82 | + with compilation_counter.expect( |
| 83 | + num_graphs_seen=1, |
| 84 | + num_piecewise_graphs_seen=3, |
| 85 | + num_piecewise_capturable_graphs_seen=2, |
| 86 | + num_backend_compilations=2, |
| 87 | + num_cudagraph_captured=4, |
| 88 | + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen |
| 89 | + ), set_forward_context({}, vllm_config=vllm_config): |
| 90 | + # first run is for compile |
| 91 | + mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) |
| 92 | + # run cudagraph captured sizes |
| 93 | + mod_A(torch.randn(2, MLP_SIZE).cuda()) |
| 94 | + mod_A(torch.randn(1, MLP_SIZE).cuda()) |
| 95 | + |
| 96 | + with set_current_vllm_config(vllm_config): |
| 97 | + mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() |
| 98 | + |
| 99 | + # B's ignore_torch_compile should override A's support_torch_compile |
| 100 | + with compilation_counter.expect( |
| 101 | + num_graphs_seen=0, |
| 102 | + num_piecewise_graphs_seen=0, |
| 103 | + num_piecewise_capturable_graphs_seen=0, |
| 104 | + num_backend_compilations=0, |
| 105 | + num_cudagraph_captured=0, |
| 106 | + ), set_forward_context({}, vllm_config=vllm_config): |
| 107 | + mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) |
| 108 | + mod_B(torch.randn(2, MLP_SIZE).cuda()) |
| 109 | + mod_B(torch.randn(1, MLP_SIZE).cuda()) |
| 110 | + |
| 111 | + with set_current_vllm_config(vllm_config): |
| 112 | + mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() |
| 113 | + |
| 114 | + # C's support_torch_compile should override B's ignore_torch_compile |
| 115 | + with compilation_counter.expect( |
| 116 | + num_graphs_seen=1, |
| 117 | + num_piecewise_graphs_seen=3, |
| 118 | + num_piecewise_capturable_graphs_seen=2, |
| 119 | + num_backend_compilations=2, |
| 120 | + num_cudagraph_captured=4, |
| 121 | + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen |
| 122 | + ), set_forward_context({}, vllm_config=vllm_config): |
| 123 | + mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) |
| 124 | + mod_C(torch.randn(2, MLP_SIZE).cuda()) |
| 125 | + mod_C(torch.randn(1, MLP_SIZE).cuda()) |
| 126 | + |
| 127 | + |
| 128 | +# Only enable torch.compile if |
| 129 | +# vllm_config.cache_config.kv_sharing_fast_prefill=True |
| 130 | +@support_torch_compile(compile_cond=lambda vllm_config: vllm_config. |
| 131 | + cache_config.kv_sharing_fast_prefill) |
| 132 | +class B(nn.Module): |
| 133 | + |
| 134 | + def __init__(self, |
| 135 | + *, |
| 136 | + vllm_config: VllmConfig, |
| 137 | + prefix: str = '', |
| 138 | + **kwargs) -> None: |
| 139 | + super().__init__() |
| 140 | + |
| 141 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 142 | + x = x + x |
| 143 | + attn_output = torch.empty_like(x) |
| 144 | + torch.ops.silly.attention(x, x, x, attn_output) |
| 145 | + x = attn_output |
| 146 | + x = x + x |
| 147 | + return x |
| 148 | + |
| 149 | + |
| 150 | +# Only enable torch.compile if |
| 151 | +# vllm_config.cache_config.kv_sharing_fast_prefill=False |
| 152 | +@support_torch_compile(compile_cond=lambda vllm_config: not vllm_config. |
| 153 | + cache_config.kv_sharing_fast_prefill) |
| 154 | +class A(nn.Module): |
| 155 | + |
| 156 | + def __init__(self, |
| 157 | + *, |
| 158 | + vllm_config: VllmConfig, |
| 159 | + prefix: str = '', |
| 160 | + **kwargs) -> None: |
| 161 | + super().__init__() |
| 162 | + self.mod1 = B(vllm_config=vllm_config, perfix=prefix, **kwargs) |
| 163 | + self.mod2 = B(vllm_config=vllm_config, perfix=prefix, **kwargs) |
| 164 | + |
| 165 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 166 | + x = self.mod1(x) |
| 167 | + attn_output = torch.empty_like(x) |
| 168 | + torch.ops.silly.attention(x, x, x, attn_output) |
| 169 | + x = attn_output |
| 170 | + x = self.mod2(x) |
| 171 | + return x |
| 172 | + |
| 173 | + |
| 174 | +def test_support_torch_compile_cond(): |
| 175 | + vllm_config = VllmConfig(cache_config=CacheConfig( |
| 176 | + kv_sharing_fast_prefill=True, ), |
| 177 | + compilation_config=CompilationConfig( |
| 178 | + level=CompilationLevel.PIECEWISE, |
| 179 | + use_cudagraph=True, |
| 180 | + splitting_ops=["silly.attention"], |
| 181 | + cudagraph_capture_sizes=[1, 2], |
| 182 | + )) |
| 183 | + |
| 184 | + with set_current_vllm_config(vllm_config): |
| 185 | + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() |
| 186 | + |
| 187 | + # A has support_torch_compile but compile_cond is not satisified |
| 188 | + # compile_cond will be satisified for B, so we expect mod1 and mod2 |
| 189 | + # to be compiled |
| 190 | + with compilation_counter.expect( |
| 191 | + num_graphs_seen=2, |
| 192 | + num_piecewise_graphs_seen=6, |
| 193 | + # 3 piecewise graphs per instance of B() |
| 194 | + num_piecewise_capturable_graphs_seen=4, |
| 195 | + num_backend_compilations=4, |
| 196 | + num_cudagraph_captured=8, |
| 197 | + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen |
| 198 | + ), set_forward_context({}, vllm_config=vllm_config): |
| 199 | + # first run is for compile |
| 200 | + mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) |
| 201 | + # run cudagraph captured sizes |
| 202 | + mod_A(torch.randn(2, MLP_SIZE).cuda()) |
| 203 | + mod_A(torch.randn(1, MLP_SIZE).cuda()) |
| 204 | + |
| 205 | + # Set kv_sharing_fast_prefill=False |
| 206 | + # which will cause A to be compiled and B to not be compiled |
| 207 | + vllm_config = VllmConfig(cache_config=CacheConfig( |
| 208 | + kv_sharing_fast_prefill=False, ), |
| 209 | + compilation_config=CompilationConfig( |
| 210 | + level=CompilationLevel.PIECEWISE, |
| 211 | + use_cudagraph=True, |
| 212 | + splitting_ops=["silly.attention"], |
| 213 | + cudagraph_capture_sizes=[1, 2], |
| 214 | + )) |
| 215 | + |
| 216 | + with set_current_vllm_config(vllm_config): |
| 217 | + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() |
| 218 | + |
| 219 | + with compilation_counter.expect( |
| 220 | + num_graphs_seen=1, |
| 221 | + num_piecewise_graphs_seen=7, |
| 222 | + # 3 attn ops and 4 non-attn ops |
| 223 | + num_piecewise_capturable_graphs_seen=4, |
| 224 | + num_backend_compilations=4, |
| 225 | + num_cudagraph_captured=8, |
| 226 | + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen |
| 227 | + ), set_forward_context({}, vllm_config=vllm_config): |
| 228 | + # first run is for compile |
| 229 | + mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) |
| 230 | + # run cudagraph captured sizes |
| 231 | + mod_A(torch.randn(2, MLP_SIZE).cuda()) |
| 232 | + mod_A(torch.randn(1, MLP_SIZE).cuda()) |
0 commit comments