Commit eae75cd

use reduce_scatter instead of all_reduce
Signed-off-by: Wuxun Zhang <[email protected]>
1 parent: 0fe5711

File tree

1 file changed: +15 additions, -5 deletions

vllm_gaudi/distributed/device_communicators/hpu_communicator.py

Lines changed: 15 additions & 5 deletions
@@ -81,6 +81,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
     def dispatch(
             self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert hidden_states.dim() == 2, "Input hidden states must be 2D"
         input_size = hidden_states.size()
         # Allocate output tensor.
         output_size = list(input_size)
@@ -108,13 +109,22 @@ def dispatch(
     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if htorch.utils.internal.is_lazy():
             htorch.core.mark_step()
+        assert hidden_states.dim() == 2, "Input hidden states must be 2D"
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
 
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        # assume num tokens is padded across DP ranks
+        assert cu_tokens_across_dp_cpu[
+            0] * self.dp_world_size == cu_tokens_across_dp_cpu[-1]
 
-        all_hidden_states = self.dp_group.all_reduce(hidden_states)
-        hidden_states = all_hidden_states[start:end, :]
+        local_hidden_states = torch.empty(
+            (cu_tokens_across_dp_cpu[0], hidden_states.size(-1)),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype)
+
+        torch.distributed.reduce_scatter_tensor(
+            local_hidden_states,
+            hidden_states,
+            group=self.dp_group.device_group)
+        hidden_states = local_hidden_states
         return hidden_states
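
Note on the change: the old combine path all-reduced the full padded hidden-state tensor on every DP rank and then sliced out each rank's own token range; the new path uses reduce_scatter_tensor so each rank receives only its reduced slice and never materializes the full cross-rank tensor. The two are equivalent only when every DP rank contributes the same padded number of tokens, which the added assert checks. The sketch below (not the repository code) contrasts the two paths; combine_all_reduce, combine_reduce_scatter, and tokens_per_rank are hypothetical names, and an initialized torch.distributed process group is assumed.

import torch
import torch.distributed as dist

def combine_all_reduce(hidden_states, group, rank, tokens_per_rank):
    # Old path: every rank materializes the full summed tensor,
    # then keeps only its own contiguous slice (valid because padding
    # makes each rank's slice exactly tokens_per_rank rows long).
    summed = hidden_states.clone()
    dist.all_reduce(summed, group=group)
    start = rank * tokens_per_rank
    return summed[start:start + tokens_per_rank, :]

def combine_reduce_scatter(hidden_states, group, tokens_per_rank):
    # New path: reduce_scatter_tensor sums across ranks and hands each
    # rank only its chunk, skipping the full-size intermediate tensor.
    local = torch.empty((tokens_per_rank, hidden_states.size(-1)),
                        device=hidden_states.device,
                        dtype=hidden_states.dtype)
    dist.reduce_scatter_tensor(local, hidden_states, group=group)
    return local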
