
Commit 86c8e41

add dp padding for prefill bs/seqlen/blocks
1 parent 0799440 commit 86c8e41

2 files changed: +73, -41 lines

vllm_gaudi/distributed/device_communicators/hpu_communicator.py

Lines changed: 22 additions & 6 deletions
@@ -81,12 +81,28 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
     def dispatch(
             self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        cu_tokens_across_dp_cpu = get_forward_context(
-        ).dp_metadata.cu_tokens_across_dp_cpu
-        hidden_states_across_dp = self.naive_multicast(
-            hidden_states, cu_tokens_across_dp_cpu)
-        router_logits_across_dp = self.naive_multicast(
-            router_logits, cu_tokens_across_dp_cpu)
+        input_size = hidden_states.size()
+        # Allocate output tensor.
+        output_size = list(input_size)
+        output_size[0] *= self.dp_world_size
+        hidden_states_across_dp = torch.empty(output_size,
+                                              dtype=hidden_states.dtype,
+                                              device=hidden_states.device)
+        torch.distributed.all_gather_into_tensor(
+            hidden_states_across_dp,
+            hidden_states,
+            group=self.dp_group.device_group)
+
+        router_logits_size = router_logits.size()
+        router_logits_output_size = list(router_logits_size)
+        router_logits_output_size[0] *= self.dp_world_size
+        router_logits_across_dp = torch.empty(router_logits_output_size,
+                                              dtype=router_logits.dtype,
+                                              device=router_logits.device)
+        torch.distributed.all_gather_into_tensor(
+            router_logits_across_dp,
+            router_logits,
+            group=self.dp_group.device_group)
         return hidden_states_across_dp, router_logits_across_dp

     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
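The rewritten dispatch gathers equal-sized shards from every data-parallel rank with a single all_gather_into_tensor call instead of the cumulative-token multicast, which is why the tensors must already be padded to the same shape on every rank. A minimal sketch of that pattern, assuming torch.distributed is initialized and every rank passes an identically shaped local tensor (the dp_group argument stands in for the communicator's device group):

import torch
import torch.distributed as dist

def gather_across_dp(local: torch.Tensor, dp_group=None) -> torch.Tensor:
    # Rank outputs are stacked along dim 0, so the output is world_size times
    # larger in that dimension; all ranks must contribute the same shape.
    world_size = dist.get_world_size(group=dp_group)
    output_shape = list(local.size())
    output_shape[0] *= world_size
    gathered = torch.empty(output_shape, dtype=local.dtype, device=local.device)
    dist.all_gather_into_tensor(gathered, local, group=dp_group)
    return gathered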

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 51 additions & 35 deletions
@@ -1423,6 +1423,10 @@ def _form_prefill_batch(self, contents):
         target_bs, target_seq, target_blocks = self._get_prompt_bucketing_fn()(
             query_lens, num_context_blocks)

+        target_bs += self.get_dp_padding(target_bs)
+        target_seq += self.get_dp_padding(target_seq)
+        target_blocks += self.get_dp_padding(target_blocks)
+
         # NOTE: If model does not support multimodal inputs, we pad here.
         # For models with multimodal support, we may want to get embeddings
         # for the valid tokens before padding.
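Here get_dp_padding is applied to the bucketed batch size, sequence length, and block count so that every data-parallel rank lands on the same padded prefill shape before the all-gather in dispatch. The helper itself is not shown in this diff; a hedged sketch of the usual idea (pad the local value up to the maximum across DP ranks), assuming an initialized torch.distributed group and not the runner's actual implementation:

import torch
import torch.distributed as dist

def get_dp_padding_sketch(value: int, dp_group=None) -> int:
    # Illustrative only: take the maximum of the local value across DP ranks
    # and return how much this rank must add to match that maximum.
    t = torch.tensor([value], dtype=torch.int64)
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=dp_group)
    return int(t.item()) - value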
@@ -1544,15 +1548,23 @@ def _create_dummy_prefill_batch_contents(
         return outputs

     def _prepare_prefill_inputs(
-            self, num_prefills, num_decodes,
-            num_scheduled_tokens: list[int]) -> tuple[PrefillInputData, int]:
+            self, num_prefills, num_decodes, num_scheduled_tokens: list[int]
+    ) -> tuple[PrefillInputData, Optional[PrefillInputData]]:
         all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents(
             num_prefills, num_decodes, num_scheduled_tokens)
         all_batches = [
             self._form_prefill_batch(bc) for bc in all_batch_contents
         ]
         merge_contents(all_batches[0], *all_batches[1:])
-        return all_batches[0], num_pad_across_dp
+
+        dummy_prefill_input_batches = None
+        if num_pad_across_dp > 0:
+            dummy_prefill_input_batches = self._create_dummy_prefill_batch_contents(
+                num_pad_across_dp)
+            merge_contents(dummy_prefill_input_batches[0],
+                           *dummy_prefill_input_batches[1:])
+        return all_batches[0], dummy_prefill_input_batches[
+            0] if dummy_prefill_input_batches else None

     def _create_decode_input_data(
             self, num_decodes, num_scheduled_tokens, context_lens,
@@ -1692,8 +1704,8 @@ def _create_decode_input_data(
         )), num_pad_across_dp

     def _prepare_decode_inputs(
-            self, num_decodes,
-            num_scheduled_tokens) -> tuple[DecodeInputData, int]:
+            self, num_decodes, num_scheduled_tokens
+    ) -> tuple[DecodeInputData, Optional[DecodeInputData]]:
         # Decodes run as one single padded batch with shape [batch, 1]
         #
         # We need to set _PAD_SLOT_ID for the padding tokens in the
@@ -1703,35 +1715,38 @@ def _prepare_decode_inputs(

         num_pad_across_dp = self.get_dp_padding(num_decodes)
         if num_decodes == 0:
-            return DecodeInputData(num_decodes=0), num_pad_across_dp
-        # BLOCK_TABLE [batch, max_num_blocks_per_req]
-        context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes]
-        block_table_cpu_tensor = self.input_batch.block_table[
-            0].get_cpu_tensor()
+            if num_pad_across_dp > 0:
+                dummy_decode_input_data = self._create_dummy_decode_input_data(
+                )
+                return DecodeInputData(num_decodes=0), dummy_decode_input_data
+            return DecodeInputData(num_decodes=0), None
         return self._create_decode_input_data(
-            num_decodes, num_scheduled_tokens, context_lens,
-            block_table_cpu_tensor, self.input_batch.num_computed_tokens_cpu,
+            num_decodes, num_scheduled_tokens,
+            self.input_batch.num_computed_tokens_cpu[:num_decodes],
+            self.input_batch.block_table[0].get_cpu_tensor(),
+            self.input_batch.num_computed_tokens_cpu,
             self.input_batch.token_ids_cpu)

     def _create_dummy_decode_input_data(self) -> DecodeInputData:
         # create dummy decode input data with batch size 1
+        num_dummy_decodes = 1
+        num_dummy_scheduled_tokens = [1]
         context_lens = [128]
         block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID],
                                              dtype=torch.int32).reshape(1, -1)
         num_computed_tokens_cpu = np.array([128], dtype=np.int32)
         token_ids = np.array(list(int(i) for i in range(context_lens[0])))

-        return self._create_decode_input_data(1, [1], context_lens,
-                                              block_table_cpu_tensor,
-                                              num_computed_tokens_cpu,
-                                              token_ids)[0]
+        return self._create_decode_input_data(
+            num_dummy_decodes, num_dummy_scheduled_tokens, context_lens,
+            block_table_cpu_tensor, num_computed_tokens_cpu, token_ids)[0]

     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
         num_prefills,
         num_decodes,
-    ) -> tuple[PrefillInputData, Optional[DecodeInputData]]:
+    ):

         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -2087,8 +2102,10 @@ def execute_model(
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             prefill_input_data, decode_input_data = self._prepare_inputs(
                 scheduler_output, num_prefills, num_decodes)
-        prefill_data, num_pad_prefill_batch_across_dp = prefill_input_data
-        decode_data, num_pad_decode_batch_across_dp = decode_input_data
+        prefill_data, dummy_prefill_input_data_batches_across_dp = prefill_input_data
+        num_pad_prefill_batch_across_dp = 0 if dummy_prefill_input_data_batches_across_dp is None else len(
+            dummy_prefill_input_data_batches_across_dp.request_ids)
+        decode_data, dummy_decode_input_data_across_dp = decode_input_data
         #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
         prefill_sampled_token_ids = []
@@ -2154,7 +2171,6 @@ def execute_model(
                         model_mm_kwargs=model_mm_kwargs,
                         warmup_mode=warmup_mode)
                 htorch.core.mark_step()
-
                 # Skip separate sampling for structured output
                 if structured_output:
                     logits_prompt.append(logits_device)
@@ -2191,17 +2207,19 @@ def execute_model(

             else:
                 if num_pad_prefill_batch_across_dp > 0:
-                    htorch.core.mark_step()
-                    dummy_prefill_input_data_list = self._create_dummy_prefill_batch_contents(
-                        num_pad_prefill_batch_across_dp)
-                    for dummy_prefill_input_data in dummy_prefill_input_data_list:
+                    for idx, (
+                            req_id, prompt_len, token_ids, position_ids,
+                            attn_metadata, logits_indices,
+                            logits_requests) in enumerate(
+                                zip(*shallow_tuple(
+                                    dummy_prefill_input_data_batches_across_dp))):
                         htorch.core.mark_step()
                         _, dummy_logits_device = \
                             self._execute_model_generic(
-                                dummy_prefill_input_data.token_ids[0],
-                                dummy_prefill_input_data.position_ids[0],
-                                dummy_prefill_input_data.attn_metadata[0],
-                                dummy_prefill_input_data.logits_indices[0],
+                                token_ids,
+                                position_ids,
+                                attn_metadata,
+                                logits_indices,
                                 self.kv_caches,
                                 warmup_mode=warmup_mode)
                         htorch.core.mark_step()
@@ -2255,15 +2273,13 @@ def execute_model(
                         is_prompt=False)
                 self.profiler.record_counter(self.event_start, counters)
             else:
-                if num_pad_decode_batch_across_dp > 0:
-                    dummy_decode_input_data = self._create_dummy_decode_input_data(
-                    )
+                if dummy_decode_input_data_across_dp is not None:
                     htorch.core.mark_step()
                     _, dummy_logits_device = self._execute_model_generic(
-                        dummy_decode_input_data.token_ids,
-                        dummy_decode_input_data.position_ids,
-                        dummy_decode_input_data.attn_metadata,
-                        dummy_decode_input_data.logits_indices,
+                        dummy_decode_input_data_across_dp.token_ids,
+                        dummy_decode_input_data_across_dp.position_ids,
+                        dummy_decode_input_data_across_dp.attn_metadata,
+                        dummy_decode_input_data_across_dp.logits_indices,
                         self.kv_caches,
                         warmup_mode=warmup_mode)
                     htorch.core.mark_step()
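The dummy prefill and decode batches exist because the MoE dispatch/combine path above is built from collectives, and those block until every data-parallel rank joins; a rank that scheduled fewer (or zero) real batches than its peers must still issue matching calls with throwaway data or the group deadlocks. A toy, self-contained sketch of that control flow with made-up Batch/run_batch names (not the runner's API):

from dataclasses import dataclass

@dataclass
class Batch:
    real: bool

def run_batch(batch: Batch) -> None:
    # Placeholder for _execute_model_generic, which reaches the MoE
    # dispatch/combine collectives on every call.
    print("running", "real" if batch.real else "dummy", "batch")

def run_dp_step(real_batches: list[Batch], num_pad_across_dp: int) -> None:
    for batch in real_batches:
        run_batch(batch)
    # Ranks that ran fewer real batches than the DP-wide maximum pad with
    # dummy batches so every rank issues the same number of collective calls.
    for _ in range(num_pad_across_dp):
        run_batch(Batch(real=False))

run_dp_step([Batch(real=True)], num_pad_across_dp=1)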
