
Commit f8c93d8

Authored by MengqingCao
[Aclgraph][DP] Fix dp dummy run not in aclgraph error (#3208)
### What this PR does / why we need it?

When running DP in a non-equilibrium scenario, i.e. when some DP groups are executing `dummy_run` while others process real requests, we need to make sure the dummy run executes in the same mode as the other DP groups, which improves performance in the DP scenario.

### How was this patch tested?

Tested by adding logging in `_dummy_run`.

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@releases/v0.11.0

---------

Signed-off-by: MengqingCao <[email protected]>
1 parent ddf4d53 commit f8c93d8
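In essence, the change makes an idle DP rank's dummy step follow the same aclgraph mode as the busy ranks instead of always falling back to eager execution. A minimal before/after sketch, assuming a hypothetical `GraphMode` enum in place of vLLM's `CUDAGraphMode` (names here are illustrative, not the actual vllm-ascend API):

```python
from enum import Enum


class GraphMode(Enum):
    """Hypothetical stand-in for vllm.config.CUDAGraphMode."""
    NONE = 0        # eager execution
    PIECEWISE = 1   # piecewise captured graph
    FULL = 2        # fully captured graph


def dummy_run_mode_before(dispatched: GraphMode) -> GraphMode:
    # Old behavior: an idle DP rank always ran its dummy batch eagerly,
    # diverging from busy ranks that replayed captured graphs.
    return GraphMode.NONE


def dummy_run_mode_after(dispatched: GraphMode) -> GraphMode:
    # New behavior: the dummy run defers to the graph dispatcher, so idle
    # and busy DP ranks execute the same mode for the same batch shape.
    return dispatched


if __name__ == "__main__":
    # With decode batches captured as FULL graphs, an idle rank now
    # replays the same graph instead of falling back to eager mode.
    print(dummy_run_mode_before(GraphMode.FULL))  # GraphMode.NONE
    print(dummy_run_mode_after(GraphMode.FULL))   # GraphMode.FULL
```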

File tree

3 files changed (+37, -23 lines)

tests/ut/worker/test_worker_v1.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -444,14 +444,17 @@ def test_execute_dummy_batch(self):
         # Create worker mock
         with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
             worker = NPUWorker()
+            worker.compilation_config = MagicMock()
+            worker.compilation_config.cudagraph_mode = MagicMock()
             mock_model_runner = MagicMock()
             worker.model_runner = mock_model_runner
 
             # Test execute_dummy_batch
             worker.execute_dummy_batch()
 
             # Verify call
-            mock_model_runner._dummy_run.assert_called_once_with(1)
+            mock_model_runner._dummy_run.assert_called_once_with(
+                num_tokens=1, uniform_decode=True, force_attention=False)
 
     @patch("vllm_ascend.worker.worker_v1.envs_vllm")
     @patch("vllm_ascend.worker.worker_v1.logger")
```

vllm_ascend/worker/model_runner_v1.py

Lines changed: 28 additions & 20 deletions
```diff
@@ -2321,6 +2321,12 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
             positions=positions,
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds)
+        forward_context = get_forward_context()
+        assert forward_context is not None
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            update_attn_params(self.update_stream, forward_context,
+                               positions.shape[0])
+
         if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
             hidden_states, _ = hidden_states
         else:
@@ -2333,12 +2339,12 @@ def _dummy_run(
         num_tokens: int,
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
     ) -> torch.Tensor:
         # only support eager mode and piecewise graph now
-        assert aclgraph_runtime_mode in {
+        assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
 
@@ -2371,8 +2377,6 @@ def _dummy_run(
         max_num_reqs = self.scheduler_config.max_num_seqs
         if uniform_decode:
             num_reqs = cdiv(num_tokens, max_query_len)
-            assert num_reqs <= max_num_reqs, \
-                "Do not capture num_reqs > max_num_reqs for uniform batch"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
@@ -2395,12 +2399,13 @@ def _dummy_run(
         if self.is_kv_producer and not self.is_kv_consumer:
             with_prefill = True
 
+        # TODO(cmq): check if with_prefill is reasonable
         attn_metadata = self._build_attention_metadata(
-            with_prefill,
-            num_reqs,
-            num_tokens,
-            max_query_len,
-            force_attention,
+            False,
+            num_reqs=num_reqs,
+            num_tokens=num_tokens,
+            max_query_len=max_query_len,
+            force_attention=force_attention,
         )
 
         if not self.in_profile_run and self.dynamic_eplb:
@@ -2433,18 +2438,21 @@ def _dummy_run(
                     k: v[:num_tokens]
                     for k, v in self.intermediate_tensors.items()
                 })
-        if aclgraph_runtime_mode == CUDAGraphMode.NONE:
-            batch_descriptor = None
-        else:
-            # filter out the valid batch descriptor
-            _cg_mode, batch_descriptor = \
-                self.aclgraph_dispatcher.dispatch(
-                    BatchDescriptor(num_tokens=num_tokens,
-                                    uniform_decode=uniform_decode))
-            # sanity check
-            assert aclgraph_runtime_mode == _cg_mode, (
+
+        # filter out the valid batch descriptor
+        _ag_mode, batch_descriptor = \
+            self.aclgraph_dispatcher.dispatch(
+                BatchDescriptor(num_tokens=num_tokens,
+                                uniform_decode=uniform_decode))
+        if aclgraph_runtime_mode is not None:
+            # we allow forcing NONE when the dispatcher disagrees to support
+            # warm ups for aclgraph capture
+            assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \
+                aclgraph_runtime_mode == _ag_mode, (
                 f"Aclgraph runtime mode mismatch at dummy_run. "
-                f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.")
+                f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
+        else:
+            aclgraph_runtime_mode = _ag_mode
 
         need_dummy_logits = (not self.in_profile_run
                              and lmhead_tp_enable())
```
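The mode-resolution logic introduced in the last hunk can be read in isolation. Below is a runnable sketch under stated assumptions: `dispatch()` is a stand-in for `self.aclgraph_dispatcher.dispatch(...)` (which in the real code also returns a batch descriptor), and the capture policy it encodes is invented for illustration:

```python
from enum import Enum
from typing import Optional


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def dispatch(num_tokens: int, uniform_decode: bool) -> CUDAGraphMode:
    # Stand-in dispatcher: pretend uniform decode batches were captured
    # as FULL graphs and everything else as PIECEWISE.
    return CUDAGraphMode.FULL if uniform_decode else CUDAGraphMode.PIECEWISE


def resolve_runtime_mode(requested: Optional[CUDAGraphMode],
                         num_tokens: int,
                         uniform_decode: bool) -> CUDAGraphMode:
    dispatched = dispatch(num_tokens, uniform_decode)
    if requested is None:
        # New default: follow the dispatcher, so a dummy run on an idle
        # DP rank picks the same mode as real runs on busy ranks.
        return dispatched
    # Forcing NONE is still allowed even when the dispatcher disagrees,
    # to support eager warm-ups before aclgraph capture.
    assert requested in (CUDAGraphMode.NONE, dispatched), (
        f"Aclgraph runtime mode mismatch at dummy_run. "
        f"Expected {dispatched}, but got {requested}.")
    return requested


if __name__ == "__main__":
    print(resolve_runtime_mode(None, num_tokens=1, uniform_decode=True))
    print(resolve_runtime_mode(CUDAGraphMode.NONE, 1, True))  # warm-up
```

Passing `None` (the new default) defers to the dispatcher; explicitly passing `CUDAGraphMode.NONE` still forces an eager pass, which the capture warm-up relies on.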

vllm_ascend/worker/worker_v1.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -26,7 +26,7 @@
 import vllm.envs as envs_vllm
 from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
 from torch_npu.profiler import dynamic_profile as dp
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
@@ -356,7 +356,10 @@ def pin_lora(self, lora_id: int) -> bool:
         return self.model_runner.pin_lora(lora_id)
 
     def execute_dummy_batch(self) -> None:
-        self.model_runner._dummy_run(1)
+        force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+        self.model_runner._dummy_run(num_tokens=1,
+                                     uniform_decode=True,
+                                     force_attention=force_attention)
 
     def _init_worker_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
```
