Commit 5d17897

[BENCH][KERNELS] Extract common code from bench_mlp.py and distributed.py into bench_utils.py (#8866)
Previously, updates to `triton_kernels` only triggered tests in `distributed.py`, leaving `bench_mlp.py` untested and causing mismatches. This change ensures shared data structures and configs appear only once, so updates to `triton_kernels` are validated consistently. Also fixed a few typing issues in `matmul.py` and `specialize.py`.
1 parent a48e358 commit 5d17897

5 files changed: +101 additions, -97 deletions

python/triton_kernels/bench/bench_mlp.py

Lines changed: 7 additions & 34 deletions
@@ -1,17 +1,13 @@
 from itertools import chain
 from pathlib import Path
-from copy import deepcopy
 import triton.profiler as proton
 import torch
 import argparse
-import triton_kernels
 import triton_kernels.roofline as roofline
-import triton_kernels.swiglu
-from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.target_info import get_cdna_version, cuda_capability_geq
+from triton_kernels.matmul import matmul
+from triton_kernels.target_info import get_cdna_version
 import distributed as triton_dist
-from triton_kernels.tensor_details import layout
-from bench_utils import quantize_weight
+from bench_utils import prepare_mlp_numerics, resolve_x_dtype
 import tempfile


@@ -40,35 +36,12 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
     b2 = triton_dist.broadcast(b2, src=ep_indx * TP, groups=groups, group_idx=ep_indx)

     # -- numerics --
-    opt1 = dict()
-    opt2 = dict()
-    if w_dtype == "mx4":
-        # on hopper we only use 8 warps when weight is scaled
-        num_warps = 4 if batch <= 512 and cuda_capability_geq(10, 0) else 8
-        value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
-        scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
-            mx_axis=1, num_warps=num_warps)
-        opt1 = {
-            "value_layout": value_layout,
-            "value_layout_opts": value_layout_opts,
-            "scale_layout": scale_layout,
-            "scale_layout_opts": scale_layout_opts,
-        }
-        opt2 = deepcopy(opt1)
-    wg, wg_flex, wg_scale = quantize_weight(wg, "bf16")
-    w1, w1_flex, w1_scale = quantize_weight(w1, w_dtype, **opt1)
-    w2, w2_flex, w2_scale = quantize_weight(w2, w_dtype, **opt2)
-    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), b_mx_scale=wg_scale)
-    act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
-                          (1.0, 1.0))
-    pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), b_mx_scale=w1_scale)
-    pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), b_mx_scale=w2_scale)
+    numerics = prepare_mlp_numerics(batch, w_dtype, wg, w1, w2)
+    wg, w1, w2 = numerics.wg, numerics.w1, numerics.w2
+    pcg, pc1, pc2, act = numerics.pcg, numerics.pc1, numerics.pc2, numerics.activation

     # -- benchmark --
-    x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
-    # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
-    if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
-        x_dtype = torch.float8_e4m3fnuz
+    x_dtype = resolve_x_dtype(x_dtype)

     input_x = torch.randn((batch // DP, dim1), device=dev)
     expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev))

python/triton_kernels/bench/bench_utils.py

Lines changed: 67 additions & 4 deletions
@@ -1,12 +1,18 @@
+from copy import deepcopy
+from dataclasses import dataclass
+
+import triton_kernels
+import triton_kernels.swiglu
+from triton_kernels.matmul import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
 from triton_kernels.numerics import InFlexData
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
-from triton_kernels.tensor import convert_layout
-from triton_kernels.tensor import wrap_torch_tensor, FP4
-from triton_kernels.target_info import is_cuda, get_cdna_version, cuda_capability_geq
+from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4, Tensor
+from triton_kernels.target_info import is_cuda, get_cdna_version, cuda_capability_geq, is_hip
+from triton_kernels.tensor_details import layout
 import torch


-def quantize_weight(w, dtype, **opt):
+def _quantize_weight(w, dtype, **opt):
     if dtype == "bf16":
         wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
         return wq, InFlexData(), None
@@ -23,3 +29,60 @@ def quantize_weight(w, dtype, **opt):
     w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
     w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
     return w, InFlexData(), w_scale
+
+
+@dataclass
+class MlpNumerics:
+    wg: torch.Tensor | Tensor | None
+    w1: torch.Tensor | Tensor | None
+    w2: torch.Tensor | Tensor | None
+    pcg: PrecisionConfig
+    pc1: PrecisionConfig
+    pc2: PrecisionConfig
+    activation: FusedActivation
+
+
+def _make_default_mlp_activation() -> FusedActivation:
+    return FusedActivation(
+        FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
+        (1.0, 1.0),
+    )
+
+
+def _make_mx4_quantization_opts(batch: int, w_dtype: str) -> dict:
+    if w_dtype != "mx4" or is_hip():
+        return {}
+    num_warps = 4 if batch <= 512 and cuda_capability_geq(10, 0) else 8
+    value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
+    scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
+    return {
+        "value_layout": value_layout,
+        "value_layout_opts": value_layout_opts,
+        "scale_layout": scale_layout,
+        "scale_layout_opts": scale_layout_opts,
+    }
+
+
+def prepare_mlp_numerics(batch: int, w_dtype: str, wg, w1, w2) -> MlpNumerics:
+    quantization_opts = _make_mx4_quantization_opts(batch, w_dtype)
+    wg, wg_flex, wg_scale = _quantize_weight(wg, "bf16")
+    w1, w1_flex, w1_scale = _quantize_weight(w1, w_dtype, **deepcopy(quantization_opts))
+    w2, w2_flex, w2_scale = _quantize_weight(w2, w_dtype, **deepcopy(quantization_opts))
+    activation = _make_default_mlp_activation()
+    return MlpNumerics(
+        wg=wg,
+        w1=w1,
+        w2=w2,
+        pcg=PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), b_mx_scale=wg_scale),
+        pc1=PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), b_mx_scale=w1_scale),
+        pc2=PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), b_mx_scale=w2_scale),
+        activation=activation,
+    )
+
+
+def resolve_x_dtype(x_dtype: str) -> torch.dtype:
+    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}
+    dtype = dtype_map[x_dtype]
+    if dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
+        return torch.float8_e4m3fnuz
+    return dtype
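
For orientation, here is how a benchmark consumes the two public helpers added above. This is a minimal sketch, not code from the commit: the shapes, `batch`, and `w_dtype` values are placeholders, and it assumes a CUDA device with `triton_kernels` installed.

# Sketch only: placeholder shapes/values; the helpers are the ones defined in bench_utils.py above.
import torch
from bench_utils import prepare_mlp_numerics, resolve_x_dtype

batch, w_dtype = 1024, "bf16"
dev = "cuda"
wg = torch.randn((512, 128), device=dev)    # gate weights (placeholder shape)
w1 = torch.randn((512, 2048), device=dev)   # first MLP weight (placeholder shape)
w2 = torch.randn((1024, 512), device=dev)   # second MLP weight (placeholder shape)

numerics = prepare_mlp_numerics(batch, w_dtype, wg, w1, w2)
wg, w1, w2 = numerics.wg, numerics.w1, numerics.w2
pcg, pc1, pc2, act = numerics.pcg, numerics.pc1, numerics.pc2, numerics.activation

x = torch.randn((batch, 512), device=dev).to(resolve_x_dtype("bf16"))

Returning an MlpNumerics dataclass keeps the quantized weights, precision configs, and fused activation together, so both benchmarks unpack the same fields instead of rebuilding them independently.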

python/triton_kernels/bench/distributed.py

Lines changed: 11 additions & 45 deletions
@@ -3,21 +3,17 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from copy import deepcopy
 from dataclasses import dataclass
 from typing import Tuple, Optional

-import triton_kernels
-import triton_kernels.swiglu
 from triton_kernels.reduce import reduce
 from triton_kernels.topk import topk
-from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
+from triton_kernels.matmul import matmul
 from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq
-from triton_kernels.tensor_details import layout
 from triton_kernels.tensor import RaggedTensorMetadata, make_ragged_tensor_metadata, remap_ragged_tensor_metadata
 from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment, symm_mem_pool

-from bench_utils import quantize_weight
+from bench_utils import prepare_mlp_numerics, resolve_x_dtype


 @dataclass
@@ -250,50 +246,20 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
     b1_full = gather_full(rank, world_size, b1, TP, EP, concat_dim_inside=1, concat_dim_outside=0)
     b2_full = gather_ep(rank, world_size, b2, TP, EP)

-    # quantization
-    opt1 = dict()
-    opt2 = dict()
-    if w_dtype == "mx4" and not is_hip():
-        # on hopper we only use 8 warps when weight is scaled
-        num_warps = 4 if batch <= 512 and cuda_capability_geq(10, 0) else 8
-        value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
-        scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
-            mx_axis=1, num_warps=num_warps)
-        opt1 = {
-            "value_layout": value_layout,
-            "value_layout_opts": value_layout_opts,
-            "scale_layout": scale_layout,
-            "scale_layout_opts": scale_layout_opts,
-        }
-        opt2 = deepcopy(opt1)
-    wg, wg_flex, wg_scale = quantize_weight(wg, "bf16")
-    w1, w1_flex, w1_scale = quantize_weight(w1, w_dtype, **opt1)
-    w2, w2_flex, w2_scale = quantize_weight(w2, w_dtype, **opt2)
+    wg_unquantized = wg
+    numerics = prepare_mlp_numerics(batch, w_dtype, wg_unquantized, w1, w2)
+    wg, w1, w2 = numerics.wg, numerics.w1, numerics.w2
+    pcg, pc1, pc2, act = numerics.pcg, numerics.pc1, numerics.pc2, numerics.activation
     if rank == 0:
-        w1_full, w1_flex_full, w1_scale_full = quantize_weight(w1_full, w_dtype, **opt1)
-        w2_full, w2_flex_full, w2_scale_full = quantize_weight(w2_full, w_dtype, **opt2)
-    else:
-        w1_full = w2_full = w1_flex_full = w2_flex_full = w1_scale_full = w2_scale_full = None
-
-    # precision configs
-    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), b_mx_scale=wg_scale)
-    act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
-                          (1.0, 1.0))
-    pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), b_mx_scale=w1_scale)
-    pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), b_mx_scale=w2_scale)
-    if rank == 0:
-        pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), b_mx_scale=w1_scale_full)
-        pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), b_mx_scale=w2_scale_full)
+        full_numerics = prepare_mlp_numerics(batch, w_dtype, wg_unquantized, w1_full, w2_full)
+        w1_full, w2_full = full_numerics.w1, full_numerics.w2
+        pc1_full, pc2_full = full_numerics.pc1, full_numerics.pc2
     else:
         pc1_full = pc2_full = None

     # inputs
-    dtype_map = {
-        "fp16": torch.float16,
-        "bf16": torch.bfloat16,
-        "fp8": torch.float8_e4m3fnuz if get_cdna_version() == 3 else torch.float8_e4m3fn,
-    }
-    xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype])
+    input_dtype = resolve_x_dtype(x_dtype)
+    xd = torch.randn((batch // world_size, dim1), device=dev).to(input_dtype)
     x0 = all_gather(xd, dim=0)
     expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev))
     symm_mem_pool.initialize_matmul(
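
One detail in the hunk above is easy to miss: `prepare_mlp_numerics` hands back already-quantized weights, so the script stashes the raw gate weight in `wg_unquantized` before the first call; presumably reusing the quantized `wg` for the rank-0 "full" configuration would quantize it a second time. A minimal sketch of the pattern, with placeholder tensors:

# Sketch only: names mirror the diff above, tensors and sizes are placeholders.
import torch
from bench_utils import prepare_mlp_numerics

wg = torch.randn(512, 128)                      # raw gate weight
w1, w2 = torch.randn(512, 2048), torch.randn(1024, 512)
w1_full, w2_full = torch.randn(512, 4096), torch.randn(2048, 512)

wg_unquantized = wg                             # keep the unquantized tensor around
numerics = prepare_mlp_numerics(1024, "bf16", wg_unquantized, w1, w2)
wg = numerics.wg                                # quantized from here on

# the rank-0 "full" configs start again from the raw gate weight
full_numerics = prepare_mlp_numerics(1024, "bf16", wg_unquantized, w1_full, w2_full)
pc1_full, pc2_full = full_numerics.pc1, full_numerics.pc2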

python/triton_kernels/triton_kernels/matmul.py

Lines changed: 12 additions & 11 deletions
@@ -6,6 +6,7 @@
 import triton
 from enum import Enum, auto
 import math
+from typing import Callable
 # utilities
 from triton_kernels import target_info
 from triton_kernels.numerics import InFlexData, OutFlexData
@@ -26,15 +27,15 @@
 @dataclass(frozen=True)
 class FusedActivation:
     specs: FnSpecs = FnSpecs.default()
-    fn_args: tuple[object] = tuple()
+    fn_args: tuple[object, ...] = tuple()


 @dataclass(frozen=True)
 class Epilogue:
     specs: FnSpecs = FnSpecs.default()
-    fn_arg_values_matmul: tuple[object] = tuple()
-    fn_arg_values_finalize: tuple[object] = tuple()
-    effective_itemsize: float = None
+    fn_arg_values_matmul: tuple[object, ...] = tuple()
+    fn_arg_values_finalize: tuple[object, ...] = tuple()
+    effective_itemsize: float | None = None

 class FnName(Enum):
     QUANTIZE_MXFP8 = auto()
@@ -86,16 +87,16 @@ class FlexCtx:

 @dataclass
 class PrecisionConfig:
-    max_num_imprecise_acc: int = None
+    max_num_imprecise_acc: int | None = None
     allow_tf32: bool = True
     flex_ctx: FlexCtx = FlexCtx()
-    acc_scale: int = 1.0
+    acc_scale: float = 1.0
     flexpoint_saturate_inf: bool = False
-    report_quantization_err_fn: callable = None
-    a_mx_scale: Tensor | None = None
-    b_mx_scale: Tensor| None = None
-    c_mx_scale: Tensor | None = None
-    out_dtype: torch.dtype = None
+    report_quantization_err_fn: Callable | None = None
+    a_mx_scale: torch.Tensor | Tensor | None = None
+    b_mx_scale: torch.Tensor | Tensor | None = None
+    c_mx_scale: torch.Tensor | Tensor | None = None
+    out_dtype: torch.dtype | None = None
     enforce_bitwise_invariance: bool = False

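
These annotation changes are the "typing issues" mentioned in the commit message: `tuple[object]` describes a tuple of exactly one element, `tuple[object, ...]` describes a tuple of any length, and a field that defaults to `None` needs `| None` in its annotation. A small self-contained illustration; the class and values here are invented for this example and are not part of `triton_kernels`:

from dataclasses import dataclass


@dataclass(frozen=True)
class ArgsSketch:
    # tuple[float] would only admit 1-tuples such as (1.0,); the variadic
    # form also covers the empty default and longer tuples.
    fn_args: tuple[float, ...] = tuple()
    # a None default requires the annotation to admit None
    acc_scale: float | None = None


ok = ArgsSketch(fn_args=(1.0, 1.0), acc_scale=0.5)  # accepted by a type checker
print(ok.fn_args, ok.acc_scale)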

python/triton_kernels/triton_kernels/specialize.py

Lines changed: 4 additions & 3 deletions
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import Optional
 import inspect
 import re
 import textwrap
@@ -66,9 +67,9 @@ def _empty_fn():
 @dataclass(frozen=True)
 class FnSpecs:
     name: str
-    fn: "triton.runtime.jit.JITFunction"
-    fn_arg_names: tuple[str]
-    fn_arg_do_not_specialize: tuple[str] = tuple()
+    fn: Optional["triton.runtime.jit.JITFunction"]
+    fn_arg_names: tuple[str, ...] = tuple()
+    fn_arg_do_not_specialize: tuple[str, ...] = tuple()
     reduction_n: int = 1

     @staticmethod
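
The `FnSpecs` change makes `fn` explicitly `Optional` and gives the argument-name tuples empty defaults, so a placeholder spec without a JIT function now type-checks. A minimal sketch under that assumption; the class below only mirrors the annotations and is not the real `FnSpecs` (whose `default()` constructor lives elsewhere in specialize.py):

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FnSpecsSketch:
    name: str
    fn: Optional[object] = None                       # stands in for "triton.runtime.jit.JITFunction"
    fn_arg_names: tuple[str, ...] = tuple()
    fn_arg_do_not_specialize: tuple[str, ...] = tuple()
    reduction_n: int = 1


empty = FnSpecsSketch(name="dflt")                    # no function, no arg names: valid under the new annotations
swiglu_like = FnSpecsSketch(name="swiglu", fn_arg_names=("alpha", "limit"), reduction_n=2)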
