|
17 | 17 | # Adapted from vllm-project/vllm/vllm/worker/worker.py
|
18 | 18 | #
|
19 | 19 |
|
| 20 | +import os |
20 | 21 | import atexit
|
21 | 22 | import functools
|
22 | 23 | import math
|
@@ -325,17 +326,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
325 | 326 | num_hidden_layers = get_max_hidden_layers(hf_config)
|
326 | 327 | parallel_config = vllm_config.parallel_config
|
327 | 328 |
|
| 329 | + if os.getenv("HCCL_OP_EXPANSION_MODE")=='AIV': |
328 | 330 | # TODO: Find out whether we need to take into account the pp_size
|
329 |
| - parallel_factor = 1 + sum(size > 1 for size in [ |
330 |
| - parallel_config.data_parallel_size_local, |
331 |
| - parallel_config.tensor_parallel_size, |
332 |
| - ]) |
333 |
| - |
334 |
| - # Calculate maximum supported batch sizes considering model architecture |
335 |
| - max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / |
336 |
| - (num_hidden_layers + 1) / parallel_factor) |
337 |
| - logger.info("Calculated maximum supported batch sizes for ACL graph: %s", |
338 |
| - max_num_batch_sizes) |
| 331 | + parallel_factor = 1 + sum(size > 1 for size in [ |
| 332 | + parallel_config.data_parallel_size, |
| 333 | + parallel_config.tensor_parallel_size, |
| 334 | + ]) |
| 335 | + |
| 336 | + # Calculate maximum supported batch sizes considering model architecture |
| 337 | + max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / |
| 338 | + (num_hidden_layers + 1) / parallel_factor) |
| 339 | + logger.info("Calculated maximum supported batch sizes for ACL graph: %s", |
| 340 | + max_num_batch_sizes) |
| 341 | + else: |
| 342 | + num_comm_groups = sum(size > 1 for size in [ |
| 343 | + parallel_config.data_parallel_size, |
| 344 | + parallel_config.tensor_parallel_size, |
| 345 | + ]) |
| 346 | + |
| 347 | + max_num_batch_sizes = math.floor((MAX_CAPTURE_SIZE - num_comm_groups * 40) / (num_hidden_layers + 1) / (1 + num_comm_groups * 2)) |
| 348 | + logger.info("Calculated maximum supported batch sizes for ACL graph: %s", |
| 349 | + max_num_batch_sizes) |
| 350 | + logger.warning("Unset HCCL_OP_EXPANSION_MODE prevents max size capture. Setting HCCL_OP_EXPANSION_MODE=AIV captures max sizes and boosts ACL graph performance.") |
339 | 351 |
|
340 | 352 | # If original sizes exceed maximum, sample a representative subset
|
341 | 353 | if max_num_batch_sizes < len(original_sizes):
|
|
0 commit comments