@@ -81,6 +81,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
     def dispatch(
             self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert hidden_states.dim() == 2, "Input hidden states must be 2D"
         input_size = hidden_states.size()
         # Allocate output tensor.
         output_size = list(input_size)
@@ -108,13 +109,22 @@ def dispatch(
     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if htorch.utils.internal.is_lazy():
             htorch.core.mark_step()
+        assert hidden_states.dim() == 2, "Input hidden states must be 2D"
         cu_tokens_across_dp_cpu = get_forward_context(
         ).dp_metadata.cu_tokens_across_dp_cpu
 
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        # assume num tokens is padded across DP ranks
+        assert cu_tokens_across_dp_cpu[
+            0] * self.dp_world_size == cu_tokens_across_dp_cpu[-1]
 
-        all_hidden_states = self.dp_group.all_reduce(hidden_states)
-        hidden_states = all_hidden_states[start:end, :]
+        local_hidden_states = torch.empty(
+            (cu_tokens_across_dp_cpu[0], hidden_states.size(-1)),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype)
+
+        torch.distributed.reduce_scatter_tensor(
+            local_hidden_states,
+            hidden_states,
+            group=self.dp_group.device_group)
+        hidden_states = local_hidden_states
 
         return hidden_states
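The core of the `combine()` change is replacing an `all_reduce` followed by a per-rank slice with a single `reduce_scatter_tensor`. That substitution is only valid when every DP rank contributes the same (padded) number of tokens, which is exactly what the new assertion on `cu_tokens_across_dp_cpu` checks: reduce-scatter hands each rank a fixed-size chunk, so the chunk boundary has to coincide with that rank's slice. Below is a minimal single-process sketch of that equivalence; the buffer shapes and variable names are illustrative only and simulate the collective semantics rather than calling `torch.distributed`.

```python
import torch

# Toy shapes; in the PR these correspond to the DP world size and the padded
# per-rank token count implied by cu_tokens_across_dp_cpu.
world_size = 2
tokens_per_rank = 4
hidden_size = 8

# One buffer per rank, each holding the padded concatenation of every rank's
# tokens (the layout of hidden_states entering combine()).
torch.manual_seed(0)
buffers = [
    torch.randn(world_size * tokens_per_rank, hidden_size)
    for _ in range(world_size)
]

for rank in range(world_size):
    # Old path: all_reduce sums every rank's buffer, then the rank slices out
    # its own rows.
    all_reduced = torch.stack(buffers).sum(dim=0)
    start = rank * tokens_per_rank
    old_result = all_reduced[start:start + tokens_per_rank, :]

    # New path: reduce-scatter semantics -- each rank receives only the sum of
    # its own chunk, never materializing the full reduced buffer.
    new_result = sum(buf.chunk(world_size, dim=0)[rank] for buf in buffers)

    # Because every rank's chunk has the same padded size, the chunk boundary
    # lines up with the slice boundary and the two paths agree.
    torch.testing.assert_close(old_result, new_result)
```

Beyond dropping the slice, reduce-scatter also avoids keeping the full all-reduced buffer on every rank, which appears to be the point of the change.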