-
Notifications
You must be signed in to change notification settings - Fork 468
[Disagg][Perf] Use NPU event sync instead of blocking tolist to avoid unintentional copy ops blocking across different NPU streams, improving disagg TTIT/TTFT #3209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b6c5ef9
9816a36
f14a98b
beabae4
3da83fe
1695f5f
c483b20
ed0b72f
1f9cb35
9c8fb4c
5be58d5
598c896
674be75
4588d12
d81f665
dd4c177
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -227,6 +227,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): | |||||||||||||||||||||
self.block_size = vllm_config.cache_config.block_size | ||||||||||||||||||||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, | ||||||||||||||||||||||
self.block_size) | ||||||||||||||||||||||
self.max_model_len = self.model_config.max_model_len | ||||||||||||||||||||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens | ||||||||||||||||||||||
decode_max_num_seqs = getattr(self.scheduler_config, | ||||||||||||||||||||||
'decode_max_num_seqs', 0) | ||||||||||||||||||||||
|
@@ -401,6 +402,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): | |||||||||||||||||||||
# Cached outputs. | ||||||||||||||||||||||
self._draft_token_ids: Optional[Union[list[list[int]], | ||||||||||||||||||||||
torch.Tensor]] = None | ||||||||||||||||||||||
self.transfer_event = torch_npu.npu.Event() | ||||||||||||||||||||||
self.sampled_token_ids_pinned_cpu = torch.empty( | ||||||||||||||||||||||
(self.max_model_len, 1), | ||||||||||||||||||||||
dtype=torch.int64, | ||||||||||||||||||||||
device="cpu", | ||||||||||||||||||||||
pin_memory=True) | ||||||||||||||||||||||
Comment on lines
+406
to
+410
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||
|
||||||||||||||||||||||
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True | ||||||||||||||||||||||
self.in_profile_run = False | ||||||||||||||||||||||
|
@@ -1906,7 +1913,7 @@ def execute_model( | |||||||||||||||||||||
max_gen_len = sampled_token_ids.shape[-1] | ||||||||||||||||||||||
if max_gen_len == 1: | ||||||||||||||||||||||
# No spec decode tokens. | ||||||||||||||||||||||
valid_sampled_token_ids = sampled_token_ids.tolist() | ||||||||||||||||||||||
valid_sampled_token_ids = self._to_list(sampled_token_ids) | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
# Includes spec decode tokens. | ||||||||||||||||||||||
valid_sampled_token_ids = self.rejection_sampler.parse_output( | ||||||||||||||||||||||
|
@@ -3054,3 +3061,18 @@ def get_supported_pooling_tasks(self): | |||||||||||||||||||||
|
||||||||||||||||||||||
def _build_drafter_prepare_inputs_torchair_param(self): | ||||||||||||||||||||||
return False | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: | ||||||||||||||||||||||
# This is a short term mitigation for issue mentioned in | ||||||||||||||||||||||
# https://github.com/vllm-project/vllm/issues/22754. | ||||||||||||||||||||||
# `tolist` would trigger a npu wise stream sync, which | ||||||||||||||||||||||
# would block other copy ops from other npu streams. | ||||||||||||||||||||||
# A npu event sync would avoid such a situation. Since | ||||||||||||||||||||||
# this is in the critical path of every single model | ||||||||||||||||||||||
# forward loop, this has caused perf issue for a disagg | ||||||||||||||||||||||
# setup. | ||||||||||||||||||||||
pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] | ||||||||||||||||||||||
pinned.copy_(sampled_token_ids, non_blocking=True) | ||||||||||||||||||||||
self.transfer_event.record() | ||||||||||||||||||||||
self.transfer_event.synchronize() | ||||||||||||||||||||||
return pinned.tolist() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test is fragile because it bypasses
__init__
and duplicates the implementation logic for creatingtransfer_event
andsampled_token_ids_pinned_cpu
within the test body. This makes the test hard to maintain, as changes in__init__
might not be reflected here, leading to the test passing while the actual code is broken, or vice-versa.A better approach is to test the behavior of
__init__
by calling it and asserting the results, while mocking its complex dependencies. Alternatively, the logic for initializing these new attributes could be extracted into a separate helper method withinNPUModelRunner
, which can then be called and tested directly. This would avoid code duplication and make the test more robust.For example, you could refactor the
NPUModelRunner
like this:And the test would become:
This approach tests the logic without duplicating it.