Commit 0799440

fix lazy hang
Signed-off-by: Wuxun Zhang <[email protected]>
1 parent 6cb02a5 commit 0799440

File tree

3 files changed: +166 -87 lines changed

vllm_gaudi/distributed/device_communicators/hpu_communicator.py

Lines changed: 45 additions & 26 deletions
@@ -1,35 +1,51 @@
 # SPDX-License-Identifier: Apache-2.0

+from typing import Optional
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup

 from vllm.distributed.device_communicators.base_device_communicator \
     import DeviceCommunicatorBase
-from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
+from vllm.distributed.parallel_state import get_dp_group

 import habana_frameworks.torch as htorch  # noqa: F401


-def naive_multicast(x: torch.Tensor,
-                    cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor:
-    assert x.dim() == 2, "Input tensor must be 2D"
-    dp_rank = get_dp_group().rank_in_group
-    dp_world_size = get_dp_group().world_size
-    buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                         device=x.device,
-                         dtype=x.dtype)
-    start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
-    end = cu_tokens_across_dp_cpu[dp_rank]
-    buffer[start:end, :].copy_(x)
-    for idx in range(dp_world_size):
-        start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-        end = cu_tokens_across_dp_cpu[idx]
-        get_dp_group().broadcast(buffer[start:end, :], idx)
-    return buffer
+class HpuCommunicator(DeviceCommunicatorBase):
+
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)

+        self.dp_group = None
+        self.dp_rank = 0
+        self.dp_world_size = 1
+        # assume EP is enabled along with DP
+        if "ep" in unique_name:
+            self.dp_group = get_dp_group()
+            self.dp_rank = self.dp_group.rank_in_group
+            self.dp_world_size = self.dp_group.world_size

-class HpuCommunicator(DeviceCommunicatorBase):
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 2, "Input tensor must be 2D"
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+        return buffer

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
@@ -67,19 +83,22 @@ def dispatch(
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
-        hidden_states_across_dp = naive_multicast(hidden_states,
-                                                  cu_tokens_across_dp_cpu)
-        router_logits_across_dp = naive_multicast(router_logits,
-                                                  cu_tokens_across_dp_cpu)
+        hidden_states_across_dp = self.naive_multicast(
+            hidden_states, cu_tokens_across_dp_cpu)
+        router_logits_across_dp = self.naive_multicast(
+            router_logits, cu_tokens_across_dp_cpu)
         return hidden_states_across_dp, router_logits_across_dp

     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        dp_rank = get_dp_group().rank_in_group
+        if htorch.utils.internal.is_lazy():
+            htorch.core.mark_step()
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
-        start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[dp_rank]

-        all_hidden_states = get_dp_group().all_reduce(hidden_states)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
         hidden_states = all_hidden_states[start:end, :]
         return hidden_states
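
Editor's note on the combine() change above: under HPU lazy mode, queued ops are only materialized at a mark_step(), so a rank can enter the blocking DP all-reduce with work still unflushed while its peers wait, which appears to be part of the hang this commit fixes. Below is a minimal sketch of that flush-before-collective pattern, assuming an already-initialized torch.distributed process group; the helper name flush_then_all_reduce is illustrative and not part of the diff.

import habana_frameworks.torch as htorch  # HPU lazy-mode utilities
import torch
import torch.distributed as dist

def flush_then_all_reduce(t: torch.Tensor) -> torch.Tensor:
    # Materialize any pending lazy-mode graph before entering the collective,
    # mirroring the is_lazy()/mark_step() guard added to combine().
    if htorch.utils.internal.is_lazy():
        htorch.core.mark_step()
    dist.all_reduce(t)  # every DP rank must reach this call together
    return t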

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 117 additions & 55 deletions
@@ -1354,11 +1354,8 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes,
             # no real prefill batches
             num_prefill_batches = 0

-        num_pad = self.get_dp_padding(num_prefill_batches)
-        if num_pad > 0:
-            for _ in range(num_pad):
-                all_batch_contents.append(BatchContents())
-        return all_batch_contents
+        num_pad_across_dp = self.get_dp_padding(num_prefill_batches)
+        return all_batch_contents, num_pad_across_dp

     def _make_attn_bias(self, context_groups, token_groups):
         dtype = self.dtype
@@ -1426,11 +1423,6 @@ def _form_prefill_batch(self, contents):
         target_bs, target_seq, target_blocks = self._get_prompt_bucketing_fn()(
             query_lens, num_context_blocks)

-        # dp aware padding
-        target_bs += self.get_dp_padding(target_bs)
-        target_seq += self.get_dp_padding(target_seq)
-        target_blocks += self.get_dp_padding(target_blocks)
-
         # NOTE: If model does not support multimodal inputs, we pad here.
         # For models with multimodal support, we may want to get embeddings
         # for the valid tokens before padding.
@@ -1523,33 +1515,49 @@ def _form_prefill_batch(self, contents):
             logits_indices=[logits_indices],
             logits_requests=[logits_requests])

+    def _create_dummy_prefill_batch_contents(
+            self, num_prefills: int) -> list[PrefillInputData]:
+        req_id = -1
+        context_len = 0
+        query_len = 128
+        prompt_tokens = 128
+        token_ids = list(int(i) for i in range(prompt_tokens))
+        num_blocks = round_up(context_len + query_len,
+                              self.block_size) // self.block_size
+        blocks = [0] * num_blocks
+        num_output_logits = context_len + query_len - prompt_tokens + 1
+        logits_positions = list(range(query_len - num_output_logits,
+                                      query_len))
+
+        new_batch_contents = BatchContents(
+            req_ids=[req_id],
+            token_ids=[token_ids],
+            context_lens=[context_len],
+            blocks=[blocks],
+            logits_positions=[logits_positions],
+        )
+
+        outputs = [
+            self._form_prefill_batch(new_batch_contents)
+            for _ in range(num_prefills)
+        ]
+        return outputs
+
     def _prepare_prefill_inputs(
             self, num_prefills, num_decodes,
-            num_scheduled_tokens: list[int]) -> PrefillInputData:
-        all_batch_contents = self._extract_prefill_batch_contents(
+            num_scheduled_tokens: list[int]) -> tuple[PrefillInputData, int]:
+        all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents(
             num_prefills, num_decodes, num_scheduled_tokens)
         all_batches = [
             self._form_prefill_batch(bc) for bc in all_batch_contents
         ]
         merge_contents(all_batches[0], *all_batches[1:])
-        return all_batches[0]
-
-    def _prepare_decode_inputs(self, num_decodes,
-                               num_scheduled_tokens) -> DecodeInputData:
-        # Decodes run as one single padded batch with shape [batch, 1]
-        #
-        # We need to set _PAD_SLOT_ID for the padding tokens in the
-        # slot_mapping, such that the attention KV cache insertion
-        # logic knows to ignore those indicies. Otherwise, the
-        # padding data can be dummy since we have a causal mask.
-
-        block_table_cpu_tensor = self.input_batch.block_table[
-            0].get_cpu_tensor()
-        if num_decodes == 0:
-            return DecodeInputData(num_decodes=0)
-        # BLOCK_TABLE [batch, max_num_blocks_per_req]
-        context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes]
+        return all_batches[0], num_pad_across_dp

+    def _create_decode_input_data(
+            self, num_decodes, num_scheduled_tokens, context_lens,
+            block_table_cpu_tensor, num_computed_tokens_cpu,
+            token_ids_cpu) -> tuple[DecodeInputData, int]:
         # NOTE(kzawora): the +1 is what causes this entire thing to work,
         # as in the paged attention, we don't fetch just the context from cache,
         # but also kvs for the current token
@@ -1561,8 +1569,9 @@ def _prepare_decode_inputs(self, num_decodes,
         padded_batch_size = self.bucketing_manager.find_decode_bucket(
             num_decodes, sum(num_blocks))[0]

-        # # dp aware padding
-        padded_batch_size += self.get_dp_padding(padded_batch_size)
+        # dp aware padding
+        num_pad_across_dp = self.get_dp_padding(padded_batch_size)
+        padded_batch_size += num_pad_across_dp

         block_tables_list = []
         for i, n in enumerate(num_blocks):
@@ -1574,8 +1583,7 @@ def _prepare_decode_inputs(self, num_decodes,
         # We slice at the end, since we use the positions for gathering.
         positions = torch.zeros((padded_batch_size, 1), dtype=torch.int32)
         positions[:num_decodes] = torch.from_numpy(
-            self.input_batch.num_computed_tokens_cpu.reshape(-1,
-                                                             1)[:num_decodes])
+            num_computed_tokens_cpu.reshape(-1, 1)[:num_decodes])
         positions = positions[:padded_batch_size]

         padded_index = torch.zeros((padded_batch_size, 1), dtype=torch.int64)
@@ -1613,11 +1621,8 @@ def _prepare_decode_inputs(self, num_decodes,

         # TOKEN_IDS. [batch, 1]
         token_ids = torch.zeros((padded_batch_size, 1), dtype=torch.int32)
-        token_ids[:num_decodes] = torch.gather(input=torch.from_numpy(
-            self.input_batch.token_ids_cpu),
-                                               dim=1,
-                                               index=index)
-
+        token_ids[:num_decodes] = torch.gather(
+            input=torch.from_numpy(token_ids_cpu), dim=1, index=index)
         # SLOT_MAPPING [batch, 1]
         # The "slot" is the "physical index" of a token in the KV cache.
         # Look up the block_idx in the block table (logical<>physical map)
@@ -1684,7 +1689,42 @@ def _prepare_decode_inputs(self, num_decodes,
                 num_decode_tokens=num_decode_tokens_device,
                 slot_mapping=slot_mapping_device,
                 block_size=self.block_size,
-            ))
+            )), num_pad_across_dp
+
+    def _prepare_decode_inputs(
+            self, num_decodes,
+            num_scheduled_tokens) -> tuple[DecodeInputData, int]:
+        # Decodes run as one single padded batch with shape [batch, 1]
+        #
+        # We need to set _PAD_SLOT_ID for the padding tokens in the
+        # slot_mapping, such that the attention KV cache insertion
+        # logic knows to ignore those indicies. Otherwise, the
+        # padding data can be dummy since we have a causal mask.
+
+        num_pad_across_dp = self.get_dp_padding(num_decodes)
+        if num_decodes == 0:
+            return DecodeInputData(num_decodes=0), num_pad_across_dp
+        # BLOCK_TABLE [batch, max_num_blocks_per_req]
+        context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes]
+        block_table_cpu_tensor = self.input_batch.block_table[
+            0].get_cpu_tensor()
+        return self._create_decode_input_data(
+            num_decodes, num_scheduled_tokens, context_lens,
+            block_table_cpu_tensor, self.input_batch.num_computed_tokens_cpu,
+            self.input_batch.token_ids_cpu)
+
+    def _create_dummy_decode_input_data(self) -> DecodeInputData:
+        # create dummy decode input data with batch size 1
+        context_lens = [128]
+        block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID],
+                                             dtype=torch.int32).reshape(1, -1)
+        num_computed_tokens_cpu = np.array([128], dtype=np.int32)
+        token_ids = np.array(list(int(i) for i in range(context_lens[0])))
+
+        return self._create_decode_input_data(1, [1], context_lens,
+                                              block_table_cpu_tensor,
+                                              num_computed_tokens_cpu,
+                                              token_ids)[0]

     def _prepare_inputs(
         self,
@@ -1740,18 +1780,7 @@ def get_dp_padding(self,
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank

-        # For DP: Don't pad when setting enforce_eager.
-        # This lets us set enforce_eager on the prefiller in a P/D setup and
-        # still use CUDA graphs (enabled by this padding) on the decoder.
-        #
-        # TODO(tms) : There are many cases where padding is enabled for
-        # prefills, causing unnecessary and excessive padding of activations.
-
-        # skip padding for non PD disagg case to avoid padding on prefill batch
-        # size and decode batch size
-        if dp_size == 1 or self.vllm_config.model_config.enforce_eager or (
-                self.vllm_config.kv_transfer_config is None
-                or self.vllm_config.kv_transfer_config.kv_connector is None):
+        if dp_size == 1:
             return 0

         num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
@@ -1768,7 +1797,6 @@ def _execute_model_generic(self,
                                warmup_mode=False,
                                inputs_embeds=None,
                                model_mm_kwargs=None):
-
         # FORWARD.
         batch_size = token_ids.size(0)
         seq_len = self._seq_len(attn_metadata)
@@ -2057,8 +2085,10 @@ def execute_model(
         num_prefills = len(pd_info.prompt_req_ids)
         num_reqs = num_decodes + num_prefills
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
-            prefill_data, decode_data = self._prepare_inputs(
+            prefill_input_data, decode_input_data = self._prepare_inputs(
                 scheduler_output, num_prefills, num_decodes)
+            prefill_data, num_pad_prefill_batch_across_dp = prefill_input_data
+            decode_data, num_pad_decode_batch_across_dp = decode_input_data
         #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
         prefill_sampled_token_ids = []
@@ -2124,6 +2154,7 @@ def execute_model(
                         model_mm_kwargs=model_mm_kwargs,
                         warmup_mode=warmup_mode)
                 htorch.core.mark_step()
+
                 # Skip separate sampling for structured output
                 if structured_output:
                     logits_prompt.append(logits_device)
@@ -2154,9 +2185,27 @@ def execute_model(
                         prompt_batch_idx=idx,
                         is_prompt=True)
                     self.profiler.record_counter(self.event_start, counters)
+
            if self.is_driver_worker and self.profiler.enabled:
                self.profiler_counter_helper.reset_prompt_seq_stats()

+        else:
+            if num_pad_prefill_batch_across_dp > 0:
+                htorch.core.mark_step()
+                dummy_prefill_input_data_list = self._create_dummy_prefill_batch_contents(
+                    num_pad_prefill_batch_across_dp)
+                for dummy_prefill_input_data in dummy_prefill_input_data_list:
+                    htorch.core.mark_step()
+                    _, dummy_logits_device = \
+                        self._execute_model_generic(
+                            dummy_prefill_input_data.token_ids[0],
+                            dummy_prefill_input_data.position_ids[0],
+                            dummy_prefill_input_data.attn_metadata[0],
+                            dummy_prefill_input_data.logits_indices[0],
+                            self.kv_caches,
+                            warmup_mode=warmup_mode)
+                    htorch.core.mark_step()
+
         ######################### DECODES #########################
         # Decodes run as one single batch with [padded_decode_bs, 1]
         if num_decodes > 0:
@@ -2205,6 +2254,19 @@ def execute_model(
                     prompt_batch_idx=None,
                     is_prompt=False)
                 self.profiler.record_counter(self.event_start, counters)
+        else:
+            if num_pad_decode_batch_across_dp > 0:
+                dummy_decode_input_data = self._create_dummy_decode_input_data(
+                )
+                htorch.core.mark_step()
+                _, dummy_logits_device = self._execute_model_generic(
+                    dummy_decode_input_data.token_ids,
+                    dummy_decode_input_data.position_ids,
+                    dummy_decode_input_data.attn_metadata,
+                    dummy_decode_input_data.logits_indices,
+                    self.kv_caches,
+                    warmup_mode=warmup_mode)
+                htorch.core.mark_step()

         if structured_output:
             # Scheduler places cached before prompt
@@ -2315,6 +2377,7 @@ def execute_model(
             prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
             pooler_output=[],
         )
+
         return model_runner_output

     def load_model(self) -> None:
@@ -2735,8 +2798,8 @@ def __del__(self):

     @torch.inference_mode()
     def profile_run(self) -> None:
+        return
         """Profile to measure peak memory during forward pass."""
-
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
         # the `dtype` argument does not matter, and we use `float32` as
@@ -2750,7 +2813,6 @@ def profile_run(self) -> None:
         max_seq_len = math.ceil(
             (self.max_num_tokens // self.max_prefill_batch_size) /
             self.block_size) * self.block_size
-        max_seq_len = min(max_seq_len, self.max_model_len)
         self._execute_dummy_scenario(
             (self.max_prefill_batch_size, max_seq_len, 0), None)
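
Editor's note on the runner changes above: the prepare-input helpers now also return a DP padding count, and execute_model() runs dummy prefill/decode batches on ranks that have no real work, so every data-parallel rank issues the same sequence of collectives instead of hanging. A simplified, hypothetical sketch of that control flow; run_forward and make_dummy_batch are illustrative placeholders, not vLLM APIs.

def execute_with_dp_padding(real_batches, num_pad_across_dp,
                            run_forward, make_dummy_batch):
    # Real work on this rank.
    for batch in real_batches:
        run_forward(batch)
    # Ranks with fewer batches still execute padded forward passes, keeping
    # collectives (e.g. the MoE dispatch/combine) aligned across DP ranks.
    for _ in range(num_pad_across_dp):
        run_forward(make_dummy_batch())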
