|
20 | 20 | import atexit
|
21 | 21 | import functools
|
22 | 22 | import math
|
| 23 | +import os |
23 | 24 | from contextlib import contextmanager
|
24 | 25 | from enum import Enum
|
25 | 26 | from threading import Lock
|
@@ -303,16 +304,47 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
303 | 304 | parallel_config = vllm_config.parallel_config
|
304 | 305 |
|
305 | 306 | # TODO: Find out whether we need to take into account the pp_size
|
306 |
| - parallel_factor = 1 + sum(size > 1 for size in [ |
307 |
| - parallel_config.data_parallel_size_local, |
| 307 | + num_comm_groups = sum(size > 1 for size in [ |
| 308 | + parallel_config.data_parallel_size, |
308 | 309 | parallel_config.tensor_parallel_size,
|
309 | 310 | ])
|
310 | 311 |
|
311 |
| - # Calculate maximum supported batch sizes considering model architecture |
312 |
| - max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / |
313 |
| - (num_hidden_layers + 1) / parallel_factor) |
314 |
| - logger.info("Calculated maximum supported batch sizes for ACL graph: %s", |
315 |
| - max_num_batch_sizes) |
| 312 | + if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV': |
| 313 | + # TODO: Find out whether we need to take into account the pp_size |
| 314 | + parallel_factor = 1 + num_comm_groups + int( |
| 315 | + parallel_config.enable_expert_parallel) |
| 316 | +        # Calculate maximum supported batch sizes considering model architecture on the A2 hardware device
| 317 | + # Assume the following case: |
| 318 | + # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, |
| 319 | + # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19 |
| 320 | + max_num_batch_sizes = math.floor( |
| 321 | + MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor) |
| 322 | + logger.info( |
| 323 | + "Calculated maximum supported batch sizes for ACL graph: %s", |
| 324 | + max_num_batch_sizes) |
| 325 | + else: |
| 326 | + # The above describes an empirical formula applicable to the A2 hardware. |
| 327 | + # Under this configuration, HCCL employs the FFTS+ method for execution unfolding, |
| 328 | + # which adds only 1 concurrent stream without consuming collective communication execution unfolding streams. |
| 329 | + # On A3 hardware, HCCL defaults to the AICPU method. |
| 330 | + # This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case). |
| 331 | + # Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes. |
| 332 | + # Therefore, the calculation formula has been modified as follows: |
| 333 | + # Assume the following case: |
| 334 | + # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, |
| 335 | + # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12 |
| 336 | + max_num_batch_sizes = math.floor( |
| 337 | + (MAX_CAPTURE_SIZE - num_comm_groups * 40) / |
| 338 | + (num_hidden_layers + 1) / (1 + num_comm_groups * 2)) |
| 339 | + logger.info( |
| 340 | + "Calculated maximum supported batch sizes for ACL graph: %s", |
| 341 | + max_num_batch_sizes) |
| 342 | + logger.warning( |
| 343 | +            "Currently, communication is performed using the default AICPU method, which reduces "
| 344 | + "the number of available streams and, as a result, limits the range of runtime " |
| 345 | + "shapes that can be handled. To both improve communication performance and " |
| 346 | + "increase the number of supported shapes, set HCCL_OP_EXPANSION_MODE=AIV." |
| 347 | + ) |
316 | 348 |
|
317 | 349 | # If original sizes exceed maximum, sample a representative subset
|
318 | 350 | if max_num_batch_sizes < len(original_sizes):
|
|
0 commit comments