Commit beacdb0

[BENCH] fast top-k routing and matmul metadata preparation (#6507)
big thanks to @apgoucher for all the help!
1 parent 5fdff50 commit beacdb0

9 files changed: +474 −103 lines changed

bench/bench/bench_mlp.py

Lines changed: 8 additions & 6 deletions
@@ -7,7 +7,7 @@
 from triton_bench.mxfp import downcast_to_mxfp
 from triton_bench.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
 from triton_bench.numerics import InFlexData
-from triton_bench.routing import routing_torch, simulate_expert_sharded_routing
+from triton_bench.routing import routing, simulate_expert_sharded_routing
 from triton_bench.meta import cuda_capability_geq


@@ -96,17 +96,19 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     for i in range(100):
         x = torch.randn((batch, dim1), device=dev)
         x = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
-        # TODO: activate proton here when fast routing is done
+        proton.activate()
         if n_expts_tot > 1:
             logits = matmul_ogs(x, wg, bg, precision_config=pcg)
-            rdata, gather_indx, scatter_indx = routing_torch(logits, n_expts_act)
+            rdata, gather_indx, scatter_indx = routing(logits, n_expts_act)
             if EP > 1:
+                proton.deactivate()
+                # TODO: activate proton here when fast expert parallelism simulation is done
                 m = logits.shape[0] * EP
                 _, rdata, gather_indx, scatter_indx = simulate_expert_sharded_routing(m, rdata, EP, device=dev)
+                proton.activate()
             x = x.to(x_dtype)
         else:
             rdata, gather_indx, scatter_indx = None, None, None
-        proton.activate()
         # c0 = torch.empty((x.shape[0], w1.shape[-1]), device=dev, dtype=x.dtype)
         # c1 = torch.empty((x.shape[0], w2.shape[-1]), device=dev, dtype=x.dtype)
         # cublas.matmul(x, w1.squeeze(0), c0)
@@ -146,5 +148,5 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     qxdtype = "fp8" if has_native_mx4 else "bf16"
     print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
     print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
-    print(bench_mlp(1024, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=2, name="llama4"))
-    print(bench_mlp(1024, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=2, name="llama4"))
+    print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4"))
+    print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=1, name="llama4"))
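Editor's note: the activate()/deactivate() pair above scopes the proton profiler so that the still-unoptimized expert-parallelism simulation stays outside the measured region. A minimal sketch of that pattern, using only the triton.profiler calls that appear in this commit; fast_path and slow_simulation are hypothetical stand-ins for the profiled and excluded work:

import torch
import triton.profiler as proton

def fast_path(x):
    # hypothetical stand-in for the work we want to measure
    return x * 2

def slow_simulation(x):
    # hypothetical stand-in for the work we want to exclude
    return x + 1

x = torch.randn(1024, device="cuda")
proton.start("scoped_example")
proton.activate()
for _ in range(10):
    y = fast_path(x)          # measured
    proton.deactivate()
    y = slow_simulation(y)    # excluded from the profile
    proton.activate()
proton.finalize()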

bench/tests/test_routing.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+import pytest
+import torch
+from triton_bench.routing import routing, routing_torch
+from triton_bench.testing import assert_close
+from triton_bench.matmul_ogs_details.metadata import compute_metadata
+from triton_bench.testing import assert_equal
+
+
+def init_data(n_tokens, n_expts_tot, dtype=torch.float16):
+    dev = "cuda"
+    # the reference implementation and the triton implementation do not tie-break experts the same way
+    randbits = [torch.randperm(n_expts_tot) for _ in range(n_tokens)]
+    x = [(-1)**i * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(dtype)) for i, bits in enumerate(randbits)]
+    return torch.stack(x).to(device=dev)
+
+
+def ref_expt_data(routing_data, n_gates, block_m):
+    hist = routing_data.expt_hist
+    n_expts_tot = routing_data.n_expts_tot
+    blks = (hist + block_m - 1) // block_m  # matmul blocks needed
+    tsum = torch.cumsum(hist, dim=0)  # prefix sum of tokens
+    bsum = torch.cumsum(blks, dim=0)  # prefix sum of blocks
+    # Get the max number of matmul blocks of size d_tile needed (and is launched with).
+    # This assumes the worst distribution of all experts with one token except for one that has the rest.
+    if n_gates <= n_expts_tot:
+        grid_m = n_gates
+    else:
+        # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
+        # ceil_div(x, y): -(-x // y)
+        grid_m = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // block_m)
+    bloc_data = -torch.ones(grid_m, dtype=torch.int32)
+    # compute data required to drive ragged batch matmul
+    for e in range(n_expts_tot):
+        offset = bsum[e - 1] if e else 0
+        for b in range(blks[e]):
+            bloc_data[offset + b] = (b << 16) + e
+
+    expt_data = torch.zeros(n_expts_tot * 3 + 2 + grid_m, dtype=torch.int32, device=hist.device)
+    expt_data[:n_expts_tot] = routing_data.expt_hist
+    expt_data[n_expts_tot + 1:n_expts_tot * 2 + 1] = tsum
+    expt_data[n_expts_tot * 2 + 2:n_expts_tot * 3 + 2] = bsum
+    expt_data[n_expts_tot * 3 + 2:] = bloc_data
+    return expt_data
+
+
+@pytest.mark.parametrize("n_tokens", [371, 255, 256, 8192, 1023, 1024])
+@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4)])
+@pytest.mark.parametrize("block_m", [64, 128])
+def test_op(n_tokens, n_expts_tot, n_expts_act, block_m):
+    torch.manual_seed(2)
+    tri_logits = init_data(n_tokens, n_expts_tot).detach()
+    ref_logits = tri_logits.clone()
+    ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act)
+    tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act)
+    ref_metadata = ref_expt_data(ref_routing_data, n_tokens * n_expts_act, block_m)
+    tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m).buffer
+
+    assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal, 2e-2, 4e-3)
+    assert_equal(ref_routing_data.expt_hist, tri_routing_data.expt_hist)
+    assert_equal(ref_metadata, tri_metadata)
+    assert ref_routing_data.n_expts_tot == tri_routing_data.n_expts_tot
+    assert ref_routing_data.n_expts_act == tri_routing_data.n_expts_act
+
+    def _assert_indx_equal(ref, tri):
+        assert_equal(ref, tri[:len(ref)])
+        assert torch.all(tri[len(ref):] == -1)
+
+    _assert_indx_equal(ref_gather.src_indx, tri_gather.src_indx)
+    _assert_indx_equal(ref_gather.dst_indx, tri_gather.dst_indx)
+    _assert_indx_equal(ref_scatter.src_indx, tri_scatter.src_indx)
+    _assert_indx_equal(ref_scatter.dst_indx, tri_scatter.dst_indx)
+
+
+def bench_routing():
+    import triton.profiler as proton
+    n_tokens = 2048
+    block_m = 128
+    n_expts_tot, n_expts_act = 128, 4
+    tri_logits = init_data(n_tokens, n_expts_tot)
+    proton.start("routing")
+    proton.activate()
+    for i in range(100):
+        tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act)
+        tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m)
+    proton.finalize()


+if __name__ == "__main__":
+    bench_routing()
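Editor's note: ref_expt_data above lays the metadata out as four consecutive segments — the per-expert token histogram, a zero-prefixed prefix sum of tokens, a zero-prefixed prefix sum of matmul blocks, and one packed (block << 16) + expert entry per launched tile. A small worked example of that layout, assuming 3 experts with token histogram [5, 1, 2] and block_m = 4:

import torch

# assumed toy inputs: 3 experts, token histogram [5, 1, 2], tile height 4
hist = torch.tensor([5, 1, 2], dtype=torch.int32)
block_m = 4
blks = (hist + block_m - 1) // block_m   # blocks per expert -> [2, 1, 1]
tok_starts = torch.cumsum(hist, 0)       # -> [5, 6, 8]
tile_starts = torch.cumsum(blks, 0)      # -> [2, 3, 4]
# n_gates = 8 > n_expts_tot = 3, so grid_m = 3 - 1 - ((3 - 8 - 1) // 4) = 4
# buffer layout (3*3 + 2 + 4 = 15 int32 entries):
#   [5, 1, 2,                            histogram
#    0, 5, 6, 8,                         token prefix sums
#    0, 2, 3, 4,                         block prefix sums
#    0x00000, 0x10000, 0x00001, 0x00002] tile infos: (block << 16) + expert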

bench/triton_bench/matmul_ogs.py

Lines changed: 3 additions & 1 deletion
@@ -16,6 +16,7 @@
 )
 from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
 from .matmul_ogs_details.opt_flags import make_opt_flags
+from .matmul_ogs_details.metadata import compute_metadata

 # -----------------------------------------------------------------------------
 # Matrix Multiplication + Outer Gather/Scatter
@@ -243,7 +244,8 @@ def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data,
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     # preprocess routing information and ptr lookup table
     M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0]
-    expt_data = routing_data.expt_data(M, opt_flags.block_m)
+    # compute expt_data
+    expt_data = compute_metadata(routing_data, M, opt_flags.block_m)
     return x, w, preprocessing_features.swap_xw, writeback_idxs, writeback_size, expt_data

 # ---------------------
bench/triton_bench/matmul_ogs_details/metadata.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+from dataclasses import dataclass
+import torch
+import triton
+import triton.language as tl
+
+
+@dataclass
+class ExptData:
+    hist: torch.Tensor
+    offs: torch.Tensor
+    offs_sum: torch.Tensor
+    blocks: torch.Tensor
+    buffer: torch.Tensor
+
+
+@triton.jit
+def _memset_metadata(Metadata, metadata_size, BLOCK: tl.constexpr):
+    pid = tl.program_id(0)
+    offs = pid * BLOCK + tl.arange(0, BLOCK)
+    tl.store(Metadata + offs, 0xffffffff, mask=offs < metadata_size)
+
+
+@triton.jit
+def _compute_metadata(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, N_EXPTS_PAD: tl.constexpr,
+                      BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
+    expt_id = tl.program_id(0)
+    n_tokens = tl.load(Hist + expt_id)
+    n_blocks = tl.cdiv(n_tokens, TILE_DIM)
+    offs_n = tl.arange(0, N_EXPTS_PAD)
+    mask = offs_n < n_expts_tot
+    hist = tl.load(Hist + offs_n, mask=mask)
+    tile_starts = tl.cumsum(tl.cdiv(hist, TILE_DIM), 0)
+    # first pid to reach this initializes histograms and cumsums
+    if expt_id == 0:
+        tok_starts = tl.cumsum(hist, 0)
+        tl.store(MDHist + offs_n, hist, mask=mask)
+        tl.store(MDTokStarts, 0)
+        tl.store(MDTokStarts + 1 + offs_n, tok_starts, mask=mask)
+        tl.store(MDTileStarts, 0)
+        tl.store(MDTileStarts + 1 + offs_n, tile_starts, mask=mask)
+    tile_off = tl.sum(tl.where(offs_n == expt_id - 1, tile_starts, 0), 0)
+    MDTileInfo += tile_off
+    # MDTileInfo += tl.load(MDTilesStart + expt_id)
+    for block_off in range(0, n_blocks, BLOCK):
+        block_offs = block_off + tl.arange(0, BLOCK)
+        data = (block_offs << 16) + expt_id
+        tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)
+
+
+def compute_metadata(routing_data, n_rows, block_m):
+    if routing_data.expt_hist is None:
+        return ExptData(None, None, None, None, None)
+    MEMSET_BLOCK = 512
+    HIST2_BLOCK_M = 512
+    device = routing_data.expt_hist.device
+    n_expts_tot = routing_data.n_expts_tot
+    cdiv = triton.cdiv
+    if n_rows <= n_expts_tot:
+        grid_m = n_rows
+    else:
+        grid_m = n_expts_tot - 1 - ((n_expts_tot - n_rows - 1) // block_m)
+    n_expts_pad = cdiv(n_expts_tot, 128) * 128
+    metadata_size = 3 * n_expts_tot + 2 + grid_m
+    metadata = torch.empty(metadata_size, dtype=torch.int32, device=device)
+    md_hist = metadata[:n_expts_tot]
+    md_tok_starts = metadata[n_expts_tot:n_expts_tot * 2 + 1]
+    md_tile_starts = metadata[n_expts_tot * 2 + 1:n_expts_tot * 3 + 2]
+    md_tile_infos = metadata[n_expts_tot * 3 + 2:]
+    _memset_metadata[(cdiv(metadata_size, MEMSET_BLOCK), )](
+        metadata, metadata_size,  # inputs
+        BLOCK=MEMSET_BLOCK  # optimization parameters
+    )
+    _compute_metadata[(n_expts_tot, )](
+        routing_data.expt_hist, n_expts_tot,  # inputs
+        md_hist, md_tok_starts, md_tile_starts, md_tile_infos,  # outputs
+        BLOCK=HIST2_BLOCK_M,  # optimization parameters
+        N_EXPTS_PAD=n_expts_pad, TILE_DIM=block_m,  # constants
+    )
+    hist = metadata[:n_expts_tot]
+    offs = metadata[n_expts_tot:2 * n_expts_tot + 1]
+    offs_sum = metadata[3 * n_expts_tot + 2 - 1]
+    blocks = metadata[n_expts_tot + 2 * (n_expts_tot + 1):]
+    return ExptData(hist, offs, offs_sum, blocks, metadata)
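Editor's note: each entry of ExptData.blocks packs a (block-within-expert, expert id) pair as (block << 16) + expt_id, and unused slots keep the 0xffffffff memset value (-1 as int32). A hypothetical decoding helper, not part of the commit, showing how a consumer can recover the pair:

import torch

def decode_tile_info(blocks: torch.Tensor):
    # drop slots that still hold the 0xffffffff memset pattern (-1 as int32)
    valid = blocks[blocks != -1]
    expt_id = valid & 0xffff        # low 16 bits: expert id
    block_in_expt = valid >> 16     # high 16 bits: block index within the expert
    return expt_id, block_in_expt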
