@@ -325,17 +325,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
325
325
num_hidden_layers = get_max_hidden_layers (hf_config )
326
326
parallel_config = vllm_config .parallel_config
327
327
328
+ if os .getenv ("HCCL_OP_EXPANSION_MODE" )== 'AIV' :
328
329
# TODO: Find out whether we need to take into account the pp_size
329
- parallel_factor = 1 + sum (size > 1 for size in [
330
- parallel_config .data_parallel_size_local ,
331
- parallel_config .tensor_parallel_size ,
332
- ])
333
-
334
- # Calculate maximum supported batch sizes considering model architecture
335
- max_num_batch_sizes = math .floor (MAX_CAPTURE_SIZE /
336
- (num_hidden_layers + 1 ) / parallel_factor )
337
- logger .info ("Calculated maximum supported batch sizes for ACL graph: %s" ,
338
- max_num_batch_sizes )
330
+ parallel_factor = 1 + sum (size > 1 for size in [
331
+ parallel_config .data_parallel_size ,
332
+ parallel_config .tensor_parallel_size ,
333
+ ])
334
+
335
+ # Calculate maximum supported batch sizes considering model architecture
336
+ max_num_batch_sizes = math .floor (MAX_CAPTURE_SIZE /
337
+ (num_hidden_layers + 1 ) / parallel_factor )
338
+ logger .info ("Calculated maximum supported batch sizes for ACL graph: %s" ,
339
+ max_num_batch_sizes )
340
+ else :
341
+ num_comm_groups = sum (size > 1 for size in [
342
+ parallel_config .data_parallel_size ,
343
+ parallel_config .tensor_parallel_size ,
344
+ ])
345
+
346
+ max_num_batch_sizes = math .floor ((MAX_CAPTURE_SIZE - num_comm_groups * 40 ) / (num_hidden_layers + 1 ) / (1 + num_comm_groups * 2 ))
347
+ logger .info ("Calculated maximum supported batch sizes for ACL graph: %s" ,
348
+ max_num_batch_sizes )
349
+ logger .warning ("Unset HCCL_OP_EXPANSION_MODE prevents max size capture. Setting HCCL_OP_EXPANSION_MODE=AIV captures max sizes and boosts ACL graph performance." )
339
350
340
351
# If original sizes exceed maximum, sample a representative subset
341
352
if max_num_batch_sizes < len (original_sizes ):
0 commit comments