
Commit 7fdd7dd

fix lazy hang
Signed-off-by: Wuxun Zhang <[email protected]>
1 parent 4ada515 commit 7fdd7dd
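
The diff below moves naive_multicast from a module-level function into HpuCommunicator, caches the data-parallel group, rank, and world size in __init__, and guards combine() with a mark_step() when the HPU bridge runs in lazy mode. A minimal standalone sketch of that guard (illustration only, not code from this commit):

import habana_frameworks.torch as htorch  # noqa: F401


def flush_lazy_graph_before_collective() -> None:
    # Sketch of the guard combine() gains in this commit: in lazy mode the HPU
    # bridge accumulates ops into a graph, and mark_step() triggers execution,
    # presumably so the data-parallel all-reduce is not issued against a
    # still-pending graph (the "lazy hang" named in the commit title).
    if htorch.utils.internal.is_lazy():
        htorch.core.mark_step()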

File tree

3 files changed: +305 −88 lines

vllm_gaudi/distributed/device_communicators/hpu_communicator.py

Lines changed: 80 additions & 26 deletions
@@ -1,35 +1,51 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Optional
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
 
 from vllm.distributed.device_communicators.base_device_communicator \
     import DeviceCommunicatorBase
-from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
+from vllm.distributed.parallel_state import get_dp_group
 
 import habana_frameworks.torch as htorch  # noqa: F401
 
 
-def naive_multicast(x: torch.Tensor,
-                    cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor:
-    assert x.dim() == 2, "Input tensor must be 2D"
-    dp_rank = get_dp_group().rank_in_group
-    dp_world_size = get_dp_group().world_size
-    buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                         device=x.device,
-                         dtype=x.dtype)
-    start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
-    end = cu_tokens_across_dp_cpu[dp_rank]
-    buffer[start:end, :].copy_(x)
-    for idx in range(dp_world_size):
-        start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-        end = cu_tokens_across_dp_cpu[idx]
-        get_dp_group().broadcast(buffer[start:end, :], idx)
-    return buffer
+class HpuCommunicator(DeviceCommunicatorBase):
 
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
 
-class HpuCommunicator(DeviceCommunicatorBase):
+        self.dp_group = None
+        self.dp_rank = 0
+        self.dp_world_size = 1
+        # assume EP is enabled along with DP
+        if "ep" in unique_name:
+            self.dp_group = get_dp_group()
+            self.dp_rank = self.dp_group.rank_in_group
+            self.dp_world_size = self.dp_group.world_size
+
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 2, "Input tensor must be 2D"
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+        return buffer
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
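
For reference, the offset arithmetic in naive_multicast() derives each rank's slice of the shared buffer from the cumulative token counts. A standalone illustration with invented counts (not part of the commit):

import torch

# Invented cumulative token counts: three DP ranks hold 4, 5, and 3 tokens.
cu_tokens_across_dp_cpu = torch.tensor([4, 9, 12])
for dp_rank in range(len(cu_tokens_across_dp_cpu)):
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    print(f"rank {dp_rank} owns buffer rows [{start}:{end})")
# rank 0 owns buffer rows [0:4)
# rank 1 owns buffer rows [4:9)
# rank 2 owns buffer rows [9:12)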
@@ -67,19 +83,57 @@ def dispatch(
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
-        hidden_states_across_dp = naive_multicast(hidden_states,
-                                                  cu_tokens_across_dp_cpu)
-        router_logits_across_dp = naive_multicast(router_logits,
-                                                  cu_tokens_across_dp_cpu)
+        hidden_states_across_dp = self.naive_multicast(
+            hidden_states, cu_tokens_across_dp_cpu)
+        router_logits_across_dp = self.naive_multicast(
+            router_logits, cu_tokens_across_dp_cpu)
         return hidden_states_across_dp, router_logits_across_dp
 
+    # def dispatch(
+    #         self, hidden_states: torch.Tensor,
+    #         router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    #     world_size = self.dp_group.world_size
+    #     input_size = hidden_states.size()
+    #     # Allocate output tensor.
+    #     output_size = list(input_size)
+    #     output_size[0] *= world_size
+    #     hidden_states_across_dp = torch.empty(output_size,
+    #                                           dtype=hidden_states.dtype,
+    #                                           device=hidden_states.device)
+    #     # All-gather.
+    #     torch.distributed.all_gather_into_tensor(
+    #         hidden_states_across_dp, hidden_states, group=self.dp_group.device_group)
+
+    #     router_logits_size = router_logits.size()
+    #     router_logits_output_size = list(router_logits_size)
+    #     router_logits_output_size[0] *= world_size
+    #     router_logits_across_dp = torch.empty(router_logits_output_size,
+    #                                           dtype=router_logits.dtype,
+    #                                           device=router_logits.device)
+    #     # All-gather.
+    #     torch.distributed.all_gather_into_tensor(
+    #         router_logits_across_dp, router_logits, group=self.dp_group.device_group)
+    #     return hidden_states_across_dp, router_logits_across_dp
+
     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        dp_rank = get_dp_group().rank_in_group
+        if htorch.utils.internal.is_lazy():
+            htorch.core.mark_step()
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
-        start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[dp_rank]
 
-        all_hidden_states = get_dp_group().all_reduce(hidden_states)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
         hidden_states = all_hidden_states[start:end, :]
         return hidden_states
+
+    # def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    #     if htorch.utils.internal.is_lazy():
+    #         htorch.core.mark_step()
+
+    #     all_hidden_states = self.dp_group.all_reduce(hidden_states)
+    #     all_hidden_states = all_hidden_states.view(self.dp_group.world_size, -1, all_hidden_states.size(-1))
+    #     hidden_states = all_hidden_states[self.dp_rank // self.dp_world_size, :, :]
+    #     return hidden_states
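
The commented-out combine() kept above reshapes the all-reduced tensor by world size instead of slicing by cumulative offsets, which only lines up when every DP rank contributed the same number of tokens. A standalone sketch of the two indexing schemes, with invented shapes and plain rank indexing (the commented-out code's own rank index expression differs; nothing here is from the commit):

import torch

# Invented sizes: 2 DP ranks, 3 tokens each, hidden size 8.
cu_tokens_across_dp_cpu = torch.tensor([3, 6])
all_hidden_states = torch.randn(6, 8)  # stand-in for the all-reduced output
dp_rank = 1

# Active combine(): slice by cumulative offsets (also handles uneven counts).
start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
end = int(cu_tokens_across_dp_cpu[dp_rank])
sliced = all_hidden_states[start:end, :]

# Reshape-based alternative: view as [world_size, tokens_per_rank, hidden] and
# index by rank, which assumes a uniform per-rank token count.
reshaped = all_hidden_states.view(2, -1, all_hidden_states.size(-1))[dp_rank, :, :]

assert torch.equal(sliced, reshaped)  # identical only because counts are uniform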
