Commit afe1a72

fix rebase error
Signed-off-by: Wuxun Zhang <[email protected]>
1 parent 295346f

2 files changed (+35, -42 lines)

vllm_gaudi/distributed/device_communicators/hpu_communicator.py

Lines changed: 4 additions & 18 deletions
@@ -8,7 +8,7 @@
 from vllm.distributed.device_communicators.base_device_communicator \
     import DeviceCommunicatorBase
 from vllm.forward_context import get_forward_context
-from vllm.distributed.parallel_state import get_dp_group
+from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group

 import habana_frameworks.torch as htorch  # noqa: F401

@@ -22,7 +22,7 @@ def __init__(self,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)

-        self.dp_group = None
+        self.dp_group: Optional[GroupCoordinator] = None
         self.dp_rank = 0
         self.dp_world_size = 1
         # assume EP is enabled along with DP
@@ -31,22 +31,6 @@ def __init__(self,
             self.dp_rank = self.dp_group.rank_in_group
             self.dp_world_size = self.dp_group.world_size

-    def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor:
-        assert x.dim() == 2, "Input tensor must be 2D"
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                             device=x.device,
-                             dtype=x.dtype)
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
-        buffer[start:end, :].copy_(x)
-        for idx in range(self.dp_world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            self.dp_group.broadcast(buffer[start:end, :], idx)
-        return buffer
-
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
@@ -81,6 +65,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
     def dispatch(
             self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.dp_group is not None
         assert hidden_states.dim() == 2, "Input hidden states must be 2D"
         input_size = hidden_states.size()
         # Allocate output tensor.
@@ -109,6 +94,7 @@ def dispatch(
     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if htorch.utils.internal.is_lazy():
             htorch.core.mark_step()
+        assert self.dp_group is not None
         assert hidden_states.dim() == 2, "Input hidden states must be 2D"
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
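
The two added `assert self.dp_group is not None` lines do double duty: they document the precondition for the EP/DP path and, now that `dp_group` is annotated as `Optional[GroupCoordinator]`, they let a static type checker narrow the attribute before it is used. A minimal, self-contained sketch of the same pattern (the `Group` and `Comm` classes below are illustrative stand-ins, not code from this repo):

from typing import Optional


class Group:
    """Illustrative stand-in for a GroupCoordinator-like object."""

    def broadcast(self, payload: list) -> list:
        return payload


class Comm:

    def __init__(self, enable_dp: bool) -> None:
        # Optional attribute, mirroring the annotation added in the diff.
        self.dp_group: Optional[Group] = Group() if enable_dp else None

    def combine(self, payload: list) -> list:
        # Without this assert, a type checker flags `self.dp_group.broadcast`
        # because `self.dp_group` may still be None at this point.
        assert self.dp_group is not None
        return self.dp_group.broadcast(payload)


print(Comm(enable_dp=True).combine([1, 2, 3]))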

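The removed `naive_multicast` helper was unused, but the slicing arithmetic it implemented is the same one `dispatch` and `combine` still rely on through `cu_tokens_across_dp_cpu`: the cumulative token counts mark which rows of the shared buffer belong to each DP rank. A small sketch with made-up numbers (not taken from the repo):

import torch

# Cumulative token counts across 3 DP ranks holding 3, 4 and 5 tokens.
cu_tokens_across_dp_cpu = torch.tensor([3, 7, 12])
hidden_size = 8
dp_world_size = cu_tokens_across_dp_cpu.numel()

# Buffer sized for all tokens across ranks, as in the removed helper.
buffer = torch.empty((int(cu_tokens_across_dp_cpu[-1]), hidden_size))

for rank in range(dp_world_size):
    start = 0 if rank == 0 else int(cu_tokens_across_dp_cpu[rank - 1])
    end = int(cu_tokens_across_dp_cpu[rank])
    # In the communicator, buffer[start:end] is the slice broadcast from
    # (or gathered for) this rank; here we only print the bounds.
    print(f"rank {rank} owns rows [{start}:{end})")
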
vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 31 additions & 24 deletions
@@ -1587,7 +1587,7 @@ def _form_prefill_batch(self, contents):

     def _create_dummy_prefill_batch_contents(
             self, num_prefills: int) -> list[PrefillInputData]:
-        req_id = -1
+        req_id = str(-1)
         context_len = 0
         query_len = 128
         prompt_tokens = 128
@@ -1616,26 +1616,30 @@ def _create_dummy_prefill_batch_contents(
     def _prepare_prefill_inputs(
             self, num_prefills, num_decodes, num_scheduled_tokens: list[int]
     ) -> tuple[PrefillInputData, Optional[PrefillInputData]]:
-        all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents(
-            num_prefills, num_decodes, num_scheduled_tokens)
+        all_batch_contents, num_pad_across_dp = \
+            self._extract_prefill_batch_contents(
+                num_prefills, num_decodes, num_scheduled_tokens)
         all_batches = [
             self._form_prefill_batch(bc) for bc in all_batch_contents
         ]
         merge_contents(all_batches[0], *all_batches[1:])

         dummy_prefill_input_batches = None
         if num_pad_across_dp > 0:
-            dummy_prefill_input_batches = self._create_dummy_prefill_batch_contents(
-                num_pad_across_dp)
+            dummy_prefill_input_batches = \
+                self._create_dummy_prefill_batch_contents(num_pad_across_dp)
             merge_contents(dummy_prefill_input_batches[0],
                            *dummy_prefill_input_batches[1:])
         return all_batches[0], dummy_prefill_input_batches[
             0] if dummy_prefill_input_batches else None

     def _create_decode_input_data(
-            self, num_decodes, num_scheduled_tokens, context_lens,
-            block_table_cpu_tensor, num_computed_tokens_cpu,
-            token_ids_cpu) -> tuple[DecodeInputData, int]:
+            self,
+            num_decodes,
+            num_scheduled_tokens,
+            context_lens,
+            block_table_cpu_tensor,
+            scheduler_output=None) -> tuple[DecodeInputData, int]:
         # NOTE(kzawora): the +1 is what causes this entire thing to work,
         # as in the paged attention, we don't fetch just the context from cache,
         # but also kvs for the current token
@@ -1842,7 +1846,10 @@ def _create_decode_input_data(
             spec_decode_metadata=spec_decode_metadata), num_pad_across_dp

     def _prepare_decode_inputs(
-            self, num_decodes, num_scheduled_tokens
+            self,
+            num_decodes,
+            num_scheduled_tokens,
+            scheduler_output=None
     ) -> tuple[DecodeInputData, Optional[DecodeInputData]]:
         # Decodes run as one single padded batch with shape [batch, 1]
         #
@@ -1861,9 +1868,7 @@ def _prepare_decode_inputs(
         return self._create_decode_input_data(
             num_decodes, num_scheduled_tokens,
             self.input_batch.num_computed_tokens_cpu[:num_decodes],
-            self.input_batch.block_table[0].get_cpu_tensor(),
-            self.input_batch.num_computed_tokens_cpu,
-            self.input_batch.token_ids_cpu)
+            self.input_batch.block_table[0].get_cpu_tensor(), scheduler_output)

     def _create_dummy_decode_input_data(self) -> DecodeInputData:
         # create dummy decode input data with batch size 1
@@ -1872,12 +1877,13 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData:
         context_lens = [128]
         block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID],
                                              dtype=torch.int32).reshape(1, -1)
-        num_computed_tokens_cpu = np.array([128], dtype=np.int32)
-        token_ids = np.array(list(int(i) for i in range(context_lens[0])))
+        # num_computed_tokens_cpu = np.array([128], dtype=np.int32)
+        # token_ids = np.array(list(int(i) for i in range(context_lens[0])))

-        return self._create_decode_input_data(
-            num_dummy_decodes, num_dummy_scheduled_tokens, context_lens,
-            block_table_cpu_tensor, num_computed_tokens_cpu, token_ids)[0]
+        return self._create_decode_input_data(num_dummy_decodes,
+                                               num_dummy_scheduled_tokens,
+                                               context_lens,
+                                               block_table_cpu_tensor)[0]

     def _get_cumsum_and_arange(
             self,
@@ -2052,8 +2058,7 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata,
         if not seen and not warmup_mode:
             logger.warning("Configuration: %s was not warmed-up!", cfg)

-    def get_dp_padding(self,
-                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
+    def get_dp_padding(self, num_tokens: int) -> int:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank

@@ -2364,9 +2369,11 @@ def execute_model(
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             prefill_input_data, decode_input_data = self._prepare_inputs(
                 scheduler_output, num_prefills, num_decodes)
-            prefill_data, dummy_prefill_input_data_batches_across_dp = prefill_input_data
-            num_pad_prefill_batch_across_dp = 0 if dummy_prefill_input_data_batches_across_dp is None else len(
-                dummy_prefill_input_data_batches_across_dp.request_ids)
+            prefill_data, \
+                dummy_prefill_input_data_batches_across_dp = prefill_input_data
+            num_pad_prefill_batch_across_dp = \
+                0 if dummy_prefill_input_data_batches_across_dp is None \
+                else len(dummy_prefill_input_data_batches_across_dp.request_ids)
         decode_data, dummy_decode_input_data_across_dp = decode_input_data
         #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
@@ -2477,7 +2484,7 @@ def execute_model(
                     zip(*shallow_tuple(
                         dummy_prefill_input_data_batches_across_dp))):
                 htorch.core.mark_step()
-                _, dummy_logits_device = \
+                _, _, dummy_logits_device = \
                     self._execute_model_generic(
                         token_ids,
                         position_ids,
@@ -2566,7 +2573,7 @@ def execute_model(
         else:
             if dummy_decode_input_data_across_dp is not None:
                 htorch.core.mark_step()
-                _, dummy_logits_device = self._execute_model_generic(
+                _, _, dummy_logits_device = self._execute_model_generic(
                     dummy_decode_input_data_across_dp.token_ids,
                     dummy_decode_input_data_across_dp.position_ids,
                     dummy_decode_input_data_across_dp.attn_metadata,
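
Several of these changes revolve around `num_pad_across_dp` and the dummy prefill/decode batches: when data-parallel ranks run in lockstep, a rank with fewer real batches than its peers must still execute dummy batches so collective operations stay matched. The sketch below only illustrates that padding idea with invented numbers; the real counts come from the scheduler and `parallel_config`, and the exact formula here is an assumption, not the body of `get_dp_padding`:

# Hypothetical per-rank batch counts for 4 DP ranks (illustrative only).
real_batches_per_rank = [3, 1, 0, 2]
max_batches = max(real_batches_per_rank)

for dp_rank, real in enumerate(real_batches_per_rank):
    # Dummy batches this rank would run so every rank executes max_batches.
    num_pad_across_dp = max_batches - real
    dummy_batches = [f"dummy-{i}" for i in range(num_pad_across_dp)]
    print(f"rank {dp_rank}: {real} real + {len(dummy_batches)} dummy batches")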
