|
 # SPDX-License-Identifier: Apache-2.0

+from typing import Optional
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup

 from vllm.distributed.device_communicators.base_device_communicator \
     import DeviceCommunicatorBase
-from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
+from vllm.distributed.parallel_state import get_dp_group

 import habana_frameworks.torch as htorch  # noqa: F401


-def naive_multicast(x: torch.Tensor,
-                    cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor:
-    assert x.dim() == 2, "Input tensor must be 2D"
-    dp_rank = get_dp_group().rank_in_group
-    dp_world_size = get_dp_group().world_size
-    buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                         device=x.device,
-                         dtype=x.dtype)
-    start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
-    end = cu_tokens_across_dp_cpu[dp_rank]
-    buffer[start:end, :].copy_(x)
-    for idx in range(dp_world_size):
-        start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-        end = cu_tokens_across_dp_cpu[idx]
-        get_dp_group().broadcast(buffer[start:end, :], idx)
-    return buffer
+class HpuCommunicator(DeviceCommunicatorBase):

+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)

-class HpuCommunicator(DeviceCommunicatorBase):
+        self.dp_group = None
+        self.dp_rank = 0
+        self.dp_world_size = 1
+        # assume EP is enabled along with DP
+        if "ep" in unique_name:
+            self.dp_group = get_dp_group()
+            self.dp_rank = self.dp_group.rank_in_group
+            self.dp_world_size = self.dp_group.world_size
+
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 2, "Input tensor must be 2D"
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+        return buffer

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
@@ -67,19 +83,57 @@ def dispatch(
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
-        hidden_states_across_dp = naive_multicast(hidden_states,
-                                                  cu_tokens_across_dp_cpu)
-        router_logits_across_dp = naive_multicast(router_logits,
-                                                  cu_tokens_across_dp_cpu)
+        hidden_states_across_dp = self.naive_multicast(
+            hidden_states, cu_tokens_across_dp_cpu)
+        router_logits_across_dp = self.naive_multicast(
+            router_logits, cu_tokens_across_dp_cpu)
         return hidden_states_across_dp, router_logits_across_dp

+    # def dispatch(
+    #         self, hidden_states: torch.Tensor,
+    #         router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    #     world_size = self.dp_group.world_size
+    #     input_size = hidden_states.size()
+    #     # Allocate output tensor.
+    #     output_size = list(input_size)
+    #     output_size[0] *= world_size
+    #     hidden_states_across_dp = torch.empty(output_size,
+    #                                           dtype=hidden_states.dtype,
+    #                                           device=hidden_states.device)
+    #     # All-gather.
+    #     torch.distributed.all_gather_into_tensor(
+    #         hidden_states_across_dp, hidden_states, group=self.dp_group.device_group)
+
+    #     router_logits_size = router_logits.size()
+    #     router_logits_output_size = list(router_logits_size)
+    #     router_logits_output_size[0] *= world_size
+    #     router_logits_across_dp = torch.empty(router_logits_output_size,
+    #                                           dtype=router_logits.dtype,
+    #                                           device=router_logits.device)
+    #     # All-gather.
+    #     torch.distributed.all_gather_into_tensor(
+    #         router_logits_across_dp, router_logits, group=self.dp_group.device_group)
+    #     return hidden_states_across_dp, router_logits_across_dp
+
     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        dp_rank = get_dp_group().rank_in_group
+        if htorch.utils.internal.is_lazy():
+            htorch.core.mark_step()
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
-        start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[dp_rank]

-        all_hidden_states = get_dp_group().all_reduce(hidden_states)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
         hidden_states = all_hidden_states[start:end, :]
         return hidden_states
+
+    # def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    #     if htorch.utils.internal.is_lazy():
+    #         htorch.core.mark_step()
+
+    #     all_hidden_states = self.dp_group.all_reduce(hidden_states)
+    #     all_hidden_states = all_hidden_states.view(self.dp_group.world_size, -1, all_hidden_states.size(-1))
+    #     hidden_states = all_hidden_states[self.dp_rank // self.dp_world_size, :, :]
+    #     return hidden_states
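As a reading aid, here is a minimal single-process sketch (not part of the change, and not using a real process group) of the offset bookkeeping that `naive_multicast` and `dispatch` rely on: each DP rank owns rows `[cu_tokens[rank - 1], cu_tokens[rank])` of a shared buffer, copies its own tokens into that slice, and broadcasts the slice so every rank ends up with the full cross-DP batch. All names and values below are made up for illustration.

```python
# Illustration only: plain CPU tensors stand in for DP ranks and broadcasts.
import torch

cu_tokens = torch.tensor([2, 5, 6])  # hypothetical cumulative token counts for 3 ranks
hidden_size = 4


def rank_slice(rank: int) -> tuple[int, int]:
    # Rank 0 starts at row 0; every later rank starts where the previous one ended.
    start = 0 if rank == 0 else int(cu_tokens[rank - 1])
    end = int(cu_tokens[rank])
    return start, end


# In the real code each rank copies its own hidden_states into its slice and
# then the owning rank broadcasts that slice; here every slice is filled
# locally just to show the resulting layout.
buffer = torch.empty((int(cu_tokens[-1]), hidden_size))
for rank in range(len(cu_tokens)):
    start, end = rank_slice(rank)
    buffer[start:end, :] = float(rank)  # stand-in for that rank's hidden states

print(buffer[:, 0].tolist())  # [0.0, 0.0, 1.0, 1.0, 1.0, 2.0]
```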
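And a matching sketch of the slice-back in `combine`: the all-reduce sums each rank's partial expert output over the full cross-DP buffer, after which every rank keeps only the rows it originally dispatched. The reduction is simulated locally here; the real code uses `self.dp_group.all_reduce`, and the tensors are invented for illustration.

```python
# Illustration only: two fake per-rank partial results replace a real all_reduce.
import torch

cu_tokens = torch.tensor([2, 5, 6])          # same hypothetical layout as above
partial_a = torch.ones(6, 4)                 # one rank's partial expert output
partial_b = torch.ones(6, 4) * 2             # another rank's partial expert output
reduced = partial_a + partial_b              # what all_reduce would leave on every rank

dp_rank = 1                                  # pretend this process is DP rank 1
start = 0 if dp_rank == 0 else int(cu_tokens[dp_rank - 1])
end = int(cu_tokens[dp_rank])

local_hidden_states = reduced[start:end, :]  # rows 2..4 belong to rank 1
assert local_hidden_states.shape == (3, 4)
```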