@@ -2321,6 +2321,12 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
2321
2321
positions = positions ,
2322
2322
intermediate_tensors = intermediate_tensors ,
2323
2323
inputs_embeds = inputs_embeds )
2324
+ forward_context = get_forward_context ()
2325
+ assert forward_context is not None
2326
+ if forward_context .cudagraph_runtime_mode == CUDAGraphMode .FULL :
2327
+ update_attn_params (self .update_stream , forward_context ,
2328
+ positions .shape [0 ])
2329
+
2324
2330
if self .drafter and self .drafter .name == SpecDcodeType .EAGLE3 :
2325
2331
hidden_states , _ = hidden_states
2326
2332
else :
@@ -2333,12 +2339,12 @@ def _dummy_run(
2333
2339
num_tokens : int ,
2334
2340
with_prefill : bool = False ,
2335
2341
is_torchair_compile : bool = False ,
2336
- aclgraph_runtime_mode : CUDAGraphMode = CUDAGraphMode . NONE ,
2342
+ aclgraph_runtime_mode : Optional [ CUDAGraphMode ] = None ,
2337
2343
force_attention : bool = False ,
2338
2344
uniform_decode : bool = False ,
2339
2345
) -> torch .Tensor :
2340
2346
# only support eager mode and piecewise graph now
2341
- assert aclgraph_runtime_mode in {
2347
+ assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
2342
2348
CUDAGraphMode .NONE , CUDAGraphMode .PIECEWISE , CUDAGraphMode .FULL
2343
2349
}
2344
2350
@@ -2371,8 +2377,6 @@ def _dummy_run(
2371
2377
max_num_reqs = self .scheduler_config .max_num_seqs
2372
2378
if uniform_decode :
2373
2379
num_reqs = cdiv (num_tokens , max_query_len )
2374
- assert num_reqs <= max_num_reqs , \
2375
- "Do not capture num_reqs > max_num_reqs for uniform batch"
2376
2380
num_scheduled_tokens_list = [max_query_len ] * num_reqs
2377
2381
if num_tokens % max_query_len != 0 :
2378
2382
num_scheduled_tokens_list [- 1 ] = num_tokens % max_query_len
@@ -2395,12 +2399,13 @@ def _dummy_run(
2395
2399
if self .is_kv_producer and not self .is_kv_consumer :
2396
2400
with_prefill = True
2397
2401
2402
+ # TODO(cmq): check if with_prefill is reasonable
2398
2403
attn_metadata = self ._build_attention_metadata (
2399
- with_prefill ,
2400
- num_reqs ,
2401
- num_tokens ,
2402
- max_query_len ,
2403
- force_attention ,
2404
+ False ,
2405
+ num_reqs = num_reqs ,
2406
+ num_tokens = num_tokens ,
2407
+ max_query_len = max_query_len ,
2408
+ force_attention = force_attention ,
2404
2409
)
2405
2410
2406
2411
if not self .in_profile_run and self .dynamic_eplb :
@@ -2433,18 +2438,21 @@ def _dummy_run(
2433
2438
k : v [:num_tokens ]
2434
2439
for k , v in self .intermediate_tensors .items ()
2435
2440
})
2436
- if aclgraph_runtime_mode == CUDAGraphMode .NONE :
2437
- batch_descriptor = None
2438
- else :
2439
- # filter out the valid batch descriptor
2440
- _cg_mode , batch_descriptor = \
2441
- self .aclgraph_dispatcher .dispatch (
2442
- BatchDescriptor (num_tokens = num_tokens ,
2443
- uniform_decode = uniform_decode ))
2444
- # sanity check
2445
- assert aclgraph_runtime_mode == _cg_mode , (
2441
+
2442
+ # filter out the valid batch descriptor
2443
+ _ag_mode , batch_descriptor = \
2444
+ self .aclgraph_dispatcher .dispatch (
2445
+ BatchDescriptor (num_tokens = num_tokens ,
2446
+ uniform_decode = uniform_decode ))
2447
+ if aclgraph_runtime_mode is not None :
2448
+ # we allow forcing NONE when the dispatcher disagrees to support
2449
+ # warm ups for aclgraph capture
2450
+ assert aclgraph_runtime_mode == CUDAGraphMode .NONE or \
2451
+ aclgraph_runtime_mode == _ag_mode , (
2446
2452
f"Aclgraph runtime mode mismatch at dummy_run. "
2447
- f"Expected { _cg_mode } , but got { aclgraph_runtime_mode } ." )
2453
+ f"Expected { _ag_mode } , but got { aclgraph_runtime_mode } ." )
2454
+ else :
2455
+ aclgraph_runtime_mode = _ag_mode
2448
2456
2449
2457
need_dummy_logits = (not self .in_profile_run
2450
2458
and lmhead_tp_enable ())
0 commit comments