|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +""" |
| 4 | +Test (piecewise) compilation with a simple model where multiple submodules |
| 5 | +are compiled and graph captured separately. |
| 6 | +""" |
| 7 | +import torch |
| 8 | +from torch import nn |
| 9 | +from torch.library import Library |
| 10 | + |
| 11 | +from vllm.compilation.backends import set_model_tag |
| 12 | +from vllm.compilation.counter import compilation_counter |
| 13 | +from vllm.compilation.decorators import (ignore_torch_compile, |
| 14 | + support_torch_compile) |
| 15 | +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, |
| 16 | + set_current_vllm_config) |
| 17 | +from vllm.envs import VLLM_USE_V1 |
| 18 | +from vllm.forward_context import set_forward_context |
| 19 | +from vllm.utils import direct_register_custom_op |
| 20 | + |
| 21 | +# create a library to hold the custom op |
| 22 | +silly_lib = Library("silly", "FRAGMENT") # noqa |
| 23 | + |
| 24 | +BATCH_SIZE = 32 |
| 25 | +MLP_SIZE = 128 |
| 26 | +HIDDEN_SIZE = 1024 |
| 27 | +RANDOM_SEED = 0 |
| 28 | + |
| 29 | + |
| 30 | +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, |
| 31 | + out: torch.Tensor) -> None: |
| 32 | + out.copy_(q) |
| 33 | + out += k |
| 34 | + out += v |
| 35 | + |
| 36 | + |
| 37 | +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, |
| 38 | + out: torch.Tensor) -> None: |
| 39 | + return |
| 40 | + |
| 41 | + |
| 42 | +direct_register_custom_op( |
| 43 | + op_name="attention", |
| 44 | + op_func=silly_attention, |
| 45 | + mutates_args=["out"], |
| 46 | + fake_impl=silly_attention_fake, |
| 47 | + target_lib=silly_lib, |
| 48 | +) |
| 49 | + |
| 50 | + |
| 51 | +@support_torch_compile |
| 52 | +class ParentModel(nn.Module): |
| 53 | + |
| 54 | + def __init__(self, |
| 55 | + *, |
| 56 | + vllm_config: VllmConfig, |
| 57 | + prefix: str = '', |
| 58 | + **kwargs) -> None: |
| 59 | + super().__init__() |
| 60 | + |
| 61 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 62 | + return x |
| 63 | + |
| 64 | + |
| 65 | +class Attention(nn.Module): |
| 66 | + |
| 67 | + def __init__(self, mlp_size: int, hidden_size: int) -> None: |
| 68 | + super().__init__() |
| 69 | + self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False) |
| 70 | + self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False) |
| 71 | + self.rms_norm_weight = nn.Parameter(torch.ones(hidden_size)) |
| 72 | + |
| 73 | + # Initialize to same weights for testing |
| 74 | + nn.init.xavier_normal_( |
| 75 | + self.pre_attn.weight.data, |
| 76 | + generator=torch.Generator().manual_seed(RANDOM_SEED), |
| 77 | + gain=0.001) |
| 78 | + nn.init.xavier_normal_( |
| 79 | + self.post_attn.weight.data, |
| 80 | + generator=torch.Generator().manual_seed(RANDOM_SEED), |
| 81 | + gain=0.001) |
| 82 | + |
| 83 | + def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor: |
| 84 | + x_f32 = x.float() |
| 85 | + return (x_f32 * torch.rsqrt( |
| 86 | + torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) * |
| 87 | + self.rms_norm_weight).to(x.dtype) |
| 88 | + |
| 89 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 90 | + x = self.pre_attn(x) |
| 91 | + x = self.rms_norm_ref(x) |
| 92 | + attn_output = torch.empty_like(x) |
| 93 | + torch.ops.silly.attention(x, x, x, attn_output) |
| 94 | + x = attn_output |
| 95 | + x = self.rms_norm_ref(x) |
| 96 | + x = self.post_attn(x) |
| 97 | + return x |
| 98 | + |
| 99 | + |
| 100 | +@support_torch_compile |
| 101 | +class CompiledAttention(nn.Module): |
| 102 | + |
| 103 | + def __init__(self, |
| 104 | + *, |
| 105 | + mlp_size: int, |
| 106 | + hidden_size: int, |
| 107 | + vllm_config: VllmConfig, |
| 108 | + prefix: str = '', |
| 109 | + **kwargs) -> None: |
| 110 | + super().__init__() |
| 111 | + self.attn = Attention(mlp_size, hidden_size) |
| 112 | + |
| 113 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 114 | + return self.attn(x) |
| 115 | + |
| 116 | + |
| 117 | +@support_torch_compile |
| 118 | +class CompiledAttentionTwo(CompiledAttention): |
| 119 | + |
| 120 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 121 | + return self.attn(x) + x |
| 122 | + |
| 123 | + |
| 124 | +@ignore_torch_compile |
| 125 | +class SimpleModelWithTwoGraphs(ParentModel): |
| 126 | + |
| 127 | + def __init__(self, |
| 128 | + *, |
| 129 | + mlp_size: int, |
| 130 | + hidden_size: int, |
| 131 | + vllm_config: VllmConfig, |
| 132 | + prefix: str = '', |
| 133 | + **kwargs) -> None: |
| 134 | + super().__init__(vllm_config=vllm_config, prefix=prefix) |
| 135 | + # Test will fail without set_model_tag here with error: |
| 136 | + # "ValueError: too many values to unpack (expected 3)" |
| 137 | + # This is because CompiledAttention and CompiledAttentionTwo |
| 138 | + # have different implmentations but the same torch.compile |
| 139 | + # cache dir will be used as default prefix is 'model_tag' |
| 140 | + with set_model_tag("attn_one"): |
| 141 | + self.attn_one = CompiledAttention( |
| 142 | + mlp_size=mlp_size, |
| 143 | + hidden_size=hidden_size, |
| 144 | + vllm_config=vllm_config, |
| 145 | + prefix=f"{prefix}.attn_one", |
| 146 | + ) |
| 147 | + with set_model_tag("attn_two"): |
| 148 | + self.attn_two = CompiledAttentionTwo( |
| 149 | + mlp_size=mlp_size, |
| 150 | + hidden_size=hidden_size, |
| 151 | + vllm_config=vllm_config, |
| 152 | + prefix=f"{prefix}.attn_two", |
| 153 | + ) |
| 154 | + |
| 155 | + self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda() |
| 156 | + |
| 157 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 158 | + bsz = x.shape[0] |
| 159 | + # CUDAGraph expects same tensor addresses for each run |
| 160 | + self.hidden_states[:bsz].copy_(x) |
| 161 | + x = self.attn_one(self.hidden_states[:bsz]) |
| 162 | + self.hidden_states[:bsz].copy_(x) |
| 163 | + x = self.attn_two(self.hidden_states[:bsz]) |
| 164 | + return x |
| 165 | + |
| 166 | + |
| 167 | +def test_ignore_torch_compile_decorator(): |
| 168 | + assert VLLM_USE_V1 |
| 169 | + |
| 170 | + # piecewise |
| 171 | + vllm_config = VllmConfig(compilation_config=CompilationConfig( |
| 172 | + level=CompilationLevel.PIECEWISE, |
| 173 | + use_cudagraph=True, |
| 174 | + splitting_ops=["silly.attention"], |
| 175 | + cudagraph_capture_sizes=[1, 2], |
| 176 | + )) |
| 177 | + |
| 178 | + @support_torch_compile |
| 179 | + class A(nn.Module): |
| 180 | + |
| 181 | + def __init__(self, |
| 182 | + *, |
| 183 | + vllm_config: VllmConfig, |
| 184 | + prefix: str = '', |
| 185 | + **kwargs) -> None: |
| 186 | + super().__init__() |
| 187 | + |
| 188 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 189 | + x = x + x |
| 190 | + attn_output = torch.empty_like(x) |
| 191 | + torch.ops.silly.attention(x, x, x, attn_output) |
| 192 | + x = attn_output |
| 193 | + x = x * 3 |
| 194 | + return x |
| 195 | + |
| 196 | + @ignore_torch_compile |
| 197 | + class B(A): |
| 198 | + ... |
| 199 | + |
| 200 | + @support_torch_compile |
| 201 | + class C(B): |
| 202 | + ... |
| 203 | + |
| 204 | + with set_current_vllm_config(vllm_config): |
| 205 | + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() |
| 206 | + |
| 207 | + # A has support_torch_compile |
| 208 | + with compilation_counter.expect( |
| 209 | + num_graphs_seen=1, |
| 210 | + num_piecewise_graphs_seen=3, |
| 211 | + num_piecewise_capturable_graphs_seen=2, |
| 212 | + num_backend_compilations=2, |
| 213 | + num_cudagraph_captured=4, |
| 214 | + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen |
| 215 | + ), set_forward_context({}, vllm_config=vllm_config): |
| 216 | + # first run is for compile |
| 217 | + mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) |
| 218 | + # run cudagraph captured sizes |
| 219 | + mod_A(torch.randn(2, MLP_SIZE).cuda()) |
| 220 | + mod_A(torch.randn(1, MLP_SIZE).cuda()) |
| 221 | + |
| 222 | + with set_current_vllm_config(vllm_config): |
| 223 | + mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() |
| 224 | + |
| 225 | + # B's ignore_torch_compile should override A's support_torch_compile |
| 226 | + with compilation_counter.expect( |
| 227 | + num_graphs_seen=0, |
| 228 | + num_piecewise_graphs_seen=0, |
| 229 | + num_piecewise_capturable_graphs_seen=0, |
| 230 | + num_backend_compilations=0, |
| 231 | + num_cudagraph_captured=0, |
| 232 | + ), set_forward_context({}, vllm_config=vllm_config): |
| 233 | + mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) |
| 234 | + mod_B(torch.randn(2, MLP_SIZE).cuda()) |
| 235 | + mod_B(torch.randn(1, MLP_SIZE).cuda()) |
| 236 | + |
| 237 | + with set_current_vllm_config(vllm_config): |
| 238 | + mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() |
| 239 | + |
| 240 | + # C's support_torch_compile should override B's ignore_torch_compile |
| 241 | + with compilation_counter.expect( |
| 242 | + num_graphs_seen=1, |
| 243 | + num_piecewise_graphs_seen=3, |
| 244 | + num_piecewise_capturable_graphs_seen=2, |
| 245 | + num_backend_compilations=2, |
| 246 | + num_cudagraph_captured=4, |
| 247 | + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen |
| 248 | + ), set_forward_context({}, vllm_config=vllm_config): |
| 249 | + mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) |
| 250 | + mod_C(torch.randn(2, MLP_SIZE).cuda()) |
| 251 | + mod_C(torch.randn(1, MLP_SIZE).cuda()) |
| 252 | + |
| 253 | + |
| 254 | +@torch.inference_mode |
| 255 | +def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor): |
| 256 | + with set_forward_context({}, vllm_config=vllm_config): |
| 257 | + # First run is for compile |
| 258 | + model(inputs) |
| 259 | + |
| 260 | + # Run CUDAGraph captured sizes |
| 261 | + model(inputs[:2]) |
| 262 | + model(inputs[:1]) |
| 263 | + |
| 264 | + output = model(inputs[:2]) |
| 265 | + |
| 266 | + output = output.cpu() |
| 267 | + return output.cpu() |
| 268 | + |
| 269 | + |
| 270 | +def test_multi_graph_piecewise_compile_outputs_equal(): |
| 271 | + outputs = [] |
| 272 | + |
| 273 | + # piecewise compile |
| 274 | + vllm_config = VllmConfig(compilation_config=CompilationConfig( |
| 275 | + level=CompilationLevel.PIECEWISE, |
| 276 | + use_cudagraph=True, |
| 277 | + splitting_ops=["silly.attention"], |
| 278 | + cudagraph_capture_sizes=[1, 2], |
| 279 | + )) |
| 280 | + |
| 281 | + with set_current_vllm_config(vllm_config): |
| 282 | + model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, |
| 283 | + hidden_size=HIDDEN_SIZE, |
| 284 | + vllm_config=vllm_config, |
| 285 | + prefix='').eval().cuda() |
| 286 | + |
| 287 | + # Pre-allocate memory for CUDAGraph which expects |
| 288 | + # static tensor addresses |
| 289 | + inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() |
| 290 | + |
| 291 | + with compilation_counter.expect( |
| 292 | + num_graphs_seen=2, # two graphs for the model |
| 293 | + num_piecewise_graphs_seen=6, |
| 294 | + # attn_one, attn_two each has 3 piecewise graphs |
| 295 | + # (pre attn, post attn, silly_attention) each |
| 296 | + num_piecewise_capturable_graphs_seen=4, |
| 297 | + # attn_one, attn_two has pre attn and post attn each, total=4 |
| 298 | + num_backend_compilations=4, # num_piecewise_capturable_graphs_seen |
| 299 | + num_cudagraph_captured=8, |
| 300 | + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen |
| 301 | + ): |
| 302 | + outputs.append(run_model(vllm_config, model, inputs)) |
| 303 | + |
| 304 | + # no compile or cudagraph |
| 305 | + vllm_config = VllmConfig(compilation_config=CompilationConfig( |
| 306 | + level=CompilationLevel.NO_COMPILATION, )) |
| 307 | + |
| 308 | + with set_current_vllm_config(vllm_config): |
| 309 | + model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, |
| 310 | + hidden_size=HIDDEN_SIZE, |
| 311 | + vllm_config=vllm_config, |
| 312 | + prefix='').eval().cuda() |
| 313 | + |
| 314 | + with compilation_counter.expect( |
| 315 | + num_graphs_seen=0, |
| 316 | + num_piecewise_graphs_seen=0, |
| 317 | + num_piecewise_capturable_graphs_seen=0, |
| 318 | + num_backend_compilations=0, |
| 319 | + num_cudagraph_captured=0, |
| 320 | + ): |
| 321 | + outputs.append(run_model(vllm_config, model, inputs)) |
| 322 | + |
| 323 | + # piecewise compile without CUDA graph |
| 324 | + vllm_config = VllmConfig(compilation_config=CompilationConfig( |
| 325 | + level=CompilationLevel.PIECEWISE, |
| 326 | + use_cudagraph=False, |
| 327 | + splitting_ops=["silly.attention"], |
| 328 | + )) |
| 329 | + |
| 330 | + with set_current_vllm_config(vllm_config): |
| 331 | + model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, |
| 332 | + hidden_size=HIDDEN_SIZE, |
| 333 | + vllm_config=vllm_config, |
| 334 | + prefix='').eval().cuda() |
| 335 | + |
| 336 | + with compilation_counter.expect( |
| 337 | + num_graphs_seen=2, |
| 338 | + num_piecewise_graphs_seen=6, |
| 339 | + num_piecewise_capturable_graphs_seen=4, |
| 340 | + num_backend_compilations=4, |
| 341 | + num_cudagraph_captured=0, # no cudagraph captured |
| 342 | + ): |
| 343 | + outputs.append(run_model(vllm_config, model, inputs)) |
| 344 | + |
| 345 | + # Generally don't expect outputs with and without inductor |
| 346 | + # to be bitwise equivalent |
| 347 | + assert torch.allclose(outputs[0], outputs[1]) |
| 348 | + |
| 349 | + # Expect bitwise equivalence using inductor w/ and w/o cudagraph |
| 350 | + assert torch.equal(outputs[0], outputs[2]) |
0 commit comments