
Commit 3ac3994

Jokeren and ThomasRaoux authored

[BENCH] Fix distributed tests (#8937)

@ThomasRaoux will try to enable the CI for testing distributed kernels.

Co-authored-by: Thomas Raoux <[email protected]>

1 parent 3530aab

File tree

1 file changed (+13, -12 lines)


python/triton_kernels/bench/distributed.py

Lines changed: 13 additions & 12 deletions
@@ -114,24 +114,24 @@ def all_gather(x: torch.Tensor, dim=0) -> torch.Tensor:
 def reduce_scatter(
     input_tensor: torch.Tensor,
     n_expts_act: int,
-    metadata: ReduceScatterMetadata,
+    metadata: Optional[ReduceScatterMetadata] = None,
     expt_assignment: Optional[ExptAssignment] = None,
     dim: int = 0,
     op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM,
 ) -> torch.Tensor:
-    if _is_distributed_launch():
-        if metadata.mode and metadata.mode == "ep_sharding":
+    if metadata and _is_distributed_launch():
+        if metadata.mode == "ep_sharding":
             if dim != 0 or op != dist.ReduceOp.SUM:
                 raise NotImplementedError("Only dim=0 and op=SUM are supported for MoE reduce_scatter.")
             output = convert_ep_to_dp(input_tensor, expt_assignment, metadata.active_indx, metadata.combine_indx)
-            # weighted average of the output token from experts
-            output = output.view(-1, n_expts_act, output.shape[-1])
-            output, _ = reduce(output, dim=1)
-            return output
         else:
             raise NotImplementedError(f"Distributed reduce_scatter mode {metadata.mode} is not implemented yet.")
     else:
-        return input_tensor
+        output = input_tensor
+    # weighted average of the output token from experts
+    output = output.view(-1, n_expts_act, output.shape[-1])
+    output, _ = reduce(output, dim=1)
+    return output


 # TODO: support TP > 1
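The net effect of this hunk is that the per-token fold over active experts now runs on both the distributed and non-distributed paths, and `reduce_scatter` can be called with `metadata=None`. Below is a minimal standalone sketch of that fold; a plain sum stands in for the benchmark's own `reduce()` helper, whose exact semantics are not shown in this diff, so treat it as an assumption rather than the file's implementation.

```python
import torch

def fold_expert_outputs(x: torch.Tensor, n_expts_act: int) -> torch.Tensor:
    # Each token contributed n_expts_act rows (one per active expert); fold them
    # back into a single row per token, mirroring the tail of reduce_scatter above.
    # A plain sum stands in for the benchmark's reduce() helper (assumption).
    x = x.view(-1, n_expts_act, x.shape[-1])
    return x.sum(dim=1)

# 8 tokens x 2 active experts, hidden size 16 -> output shape (8, 16)
out = fold_expert_outputs(torch.randn(16, 16), n_expts_act=2)
print(out.shape)  # torch.Size([8, 16])
```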
@@ -283,7 +283,8 @@ def single(x):
     else:
         rdata = gi = si = None
     x = matmul(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act)
-    return matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)
+    x = matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)
+    return reduce_scatter(x, n_expts_act, metadata=None, expt_assignment=None)

 # distributed pass
 def distributed(x):
@@ -328,17 +329,17 @@ def distributed(x):
     + [
         (128, 1024, 1024, 128, 2, "bf16", "bf16", 1, 1),
         (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 1),
-        (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 4),
+        (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 2),
     ] +
     # moe cases - test precision
     ([
         (128, 1024, 1024, 128, 2, "fp8", "mx4", 1, 1),
         (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 1),
-        (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 4),
+        (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 2),
     ] if has_native_mx4 else [
         (128, 1024, 1024, 128, 2, "bf16", "mx4", 1, 1),
         (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 1),
-        (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 4),
+        (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 2),
     ]),
 )
 def test_mlp_mp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, monkeypatch):
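Matching the `test_mlp_mp` signature, the last two columns of each tuple are the TP and EP degrees, so the EP=4 cases were reduced to EP=2, which fits a two-rank launch. As a rough illustration only, and not the repo's own launcher or CI setup, the sketch below spins up a two-process torch.distributed group of the kind an EP=2 test needs; the all_reduce is a stand-in for the benchmark's collectives.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int) -> None:
    # One process per expert-parallel rank; gloo keeps the sketch CPU-friendly.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    x = torch.full((4, 8), float(rank + 1))
    dist.all_reduce(x, op=dist.ReduceOp.SUM)  # stand-in for the benchmark's collectives
    assert torch.equal(x, torch.full((4, 8), float(sum(range(1, world_size + 1)))))
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2  # matches the EP=2 cases above
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)
```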

0 commit comments