Skip to content

Commit cd62a76

Browse files
authored
[PROTON] Simplify proton viewer APIs for bench_mlp analysis (#6452)
1 parent 5cf16d7 commit cd62a76

File tree

7 files changed

+208
-195
lines changed

7 files changed

+208
-195
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 28 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from pathlib import Path
22
from copy import deepcopy
33
import matplotlib.pyplot as plt
4-
import json
54
import triton.profiler as proton
5+
from triton.profiler import viewer
66
import torch
77
import triton_kernels
88
import triton_kernels.swiglu
@@ -21,35 +21,6 @@
2121
cublas = None
2222

2323

24-
def _query_gpu_specs():
25-
import subprocess
26-
if is_hip():
27-
cmd = ["rocm-smi", "--showproductname", "-d=0", "--csv"]
28-
output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
29-
model = output.splitlines()[1].split(",")[2]
30-
if model in ["0x74a9", "0x74a1"]:
31-
name = "AMD Instinct MI300X"
32-
elif model == "0x74a5":
33-
name = "AMD Instinct MI325X"
34-
else:
35-
name = "AMD"
36-
else:
37-
cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader", "-i=0"]
38-
output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
39-
name = output.splitlines()[0]
40-
41-
gpu_specs = {
42-
"NVIDIA H100 80GB HBM3": {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35},
43-
"NVIDIA GB200": {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0},
44-
"AMD Instinct MI300X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 5.3},
45-
"AMD Instinct MI325X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 6.0},
46-
}
47-
return gpu_specs.get(name)
48-
49-
50-
SPECS = _query_gpu_specs()
51-
52-
5324
def quantize(w, dtype, dev, **opt):
5425
if dtype == "bf16":
5526
wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
@@ -80,6 +51,8 @@ class PerfData:
8051
flops: float
8152
bytes: float
8253
bitwidth: int
54+
device_type: str
55+
device_info: dict
8356

8457
@property
8558
def tflops(self):
@@ -95,15 +68,20 @@ def opint(self):
9568
assert self.bytes > 0
9669
return self.flops / self.bytes
9770

71+
@property
72+
def max_tbps(self):
73+
return proton.specs.max_bps(self.device_info["bus_width"], self.device_info["memory_clock_rate"]) * 1e-12
74+
75+
@property
76+
def max_tflops(self):
77+
return proton.specs.max_flops(self.device_type, self.device_info["arch"], self.bitwidth,
78+
self.device_info["num_sms"], self.device_info["clock_rate"]) * 1e-12
79+
9880
@property
9981
def util(self) -> float:
100-
if SPECS is None:
101-
return 0.0
10282
assert self.bitwidth in (8, 16)
103-
104-
peak_flops = SPECS["MAX_TFLOPS8"] if self.bitwidth == 8 else SPECS["MAX_TFLOPS16"]
105-
min_t_flop = self.flops / peak_flops * 1e-3 # ns → µs
106-
min_t_bw = self.bytes / SPECS["MAX_TBPS"] * 1e-3
83+
min_t_flop = self.flops / self.max_tflops * 1e-3
84+
min_t_bw = self.bytes / self.max_tbps * 1e-3
10785
return max(min_t_flop, min_t_bw) / self.time
10886

10987

@@ -171,21 +149,17 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
171149
proton.finalize()
172150

173151
# -- analyze --
174-
with open(f"{fpath}") as fd:
175-
data = json.load(fd)
176-
# TODO: this will be broken if kernels use scopes themselves
177-
# compute useful (a.k.a. matmul) bytes and flops
178-
matmuls = [
179-
x for x in data[0]["children"] if "_matmul" in x["frame"]["name"] and "metadata" not in x["frame"]["name"]
180-
]
181-
bytes = sum([x["metrics"]["bytes"] for x in matmuls])
182-
flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
183-
flops = sum([flops[w] for w in [8, 16]])
184-
# compute total time (incl. "not useful" work)
185-
# TODO: proton should really be recording that in the json instead of
186-
# relying on the user to aggregate
187-
time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
188-
return PerfData(time, flops, bytes, x_dtype.itemsize * 8)
152+
gf, _, _, info = viewer.read(fpath)
153+
# Now the dataframe only contains leave nodes (i.e., kernels) that perform matmuls
154+
matmuls = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*matmul.*' AND c IS LEAF").dataframe
155+
bytes = matmuls["bytes"].sum()
156+
flops = sum(matmuls[[c for c in ["flops8", "flops16"] if c in matmuls.columns]].sum())
157+
time = matmuls["time (ns)"].sum()
158+
device_type = matmuls["device_type"].iloc[0]
159+
device_id = matmuls["device_id"].iloc[0]
160+
device_info = info[device_type][device_id]
161+
return PerfData(time=time, flops=flops, bytes=bytes, bitwidth=x.dtype.itemsize * 8, device_type=device_type,
162+
device_info=device_info)
189163

190164

191165
def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
@@ -204,6 +178,8 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
204178
print(f"Batch: {batch}; Util: {perfs[-1].util}; TFLOPS: {perfs[-1].tflops}; TBPS: {perfs[-1].tbps}")
205179
print("===============================================================")
206180
# machine limits
181+
max_tbps = perfs[0].max_tbps
182+
max_tflops = perfs[0].max_tflops
207183
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
208184
ax.set_xlabel("batch size (toks/expt)")
209185
ax.set_ylabel("performance [TFLOP/s]")
@@ -214,10 +190,8 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
214190
xmin, xmax = min(xs), max(xs)
215191
dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
216192
ax.set_xlim(xmin - dx, xmax + dx)
217-
ax.set_ylim(100, SPECS["MAX_TFLOPS8"] + 500)
193+
ax.set_ylim(100, max_tflops + 500)
218194
# plot roofline
219-
max_tbps = SPECS["MAX_TBPS"]
220-
max_tflops = SPECS["MAX_TFLOPS8"]
221195
opints = [p.opint for p in perfs]
222196
knee = bisect_left(opints, max_tflops / max_tbps) - 1
223197
x_bw, x_comp = xs[:knee], xs[knee:]
@@ -237,8 +211,6 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
237211

238212
if __name__ == "__main__":
239213
has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
240-
if SPECS is None:
241-
print("Current GPU has no specs provided, utilization is N/A")
242214
batch_ranges_dense = [(1024, 32768, 1024)]
243215
batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
244216
dense_dtypes = ["fp8", "fp8"]

python/triton_kernels/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "triton_kernels"
33
version = "1.0.0"
4-
dependencies = ["torch", "numpy", "pytest"]
4+
dependencies = ["torch", "numpy", "pytest", "llnl-hatchet"]
55

66
[build-system]
77
requires = ["setuptools>=64.0"]

third_party/proton/proton/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
profile,
1010
DEFAULT_PROFILE_NAME,
1111
)
12-
from . import context
12+
from . import context, specs

third_party/proton/proton/specs.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
flops_by_device = {
2+
"CUDA": {
3+
"80":
4+
lambda width, **kwargs: 624e12 / (width / 8),
5+
"89":
6+
lambda width, **kwargs: (330.3 * 1e12) / (width / 8), # TODO(Keren): Implement fp16 acc-> 660.6 fp8
7+
"90":
8+
lambda width, num_sms, clock_rate, **kwargs: ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) /
9+
(width / 8),
10+
"100":
11+
lambda width, num_sms, clock_rate, **kwargs: (num_sms * 16384 * (clock_rate / 1e3) * 1e6) / (width / 8),
12+
},
13+
"HIP": {
14+
"gfx90a": lambda width, **kwargs: 383e12 / (width / 8),
15+
"gfx942": lambda width, **kwargs: 2614.9e12 / (width / 8),
16+
},
17+
}
18+
19+
20+
def max_flops(device_type, arch, width, num_sms, clock_rate):
21+
"""
22+
Calculate the maximum FLOPS for a given device type and width.
23+
24+
Args:
25+
device_type (str): The type of device (e.g., "CUDA", "HIP").
26+
arch (str): The architecture of the device (e.g., "80", "90").
27+
width (int): The width in bits.
28+
num_sms (int): The number of streaming multiprocessors.
29+
clock_rate (float): The clock rate in GHz.
30+
31+
Returns:
32+
float: The maximum FLOPS for the given device type and width.
33+
"""
34+
if device_type not in flops_by_device:
35+
raise ValueError(f"Unsupported device type: {device_type}")
36+
37+
if arch not in flops_by_device[device_type]:
38+
raise ValueError(f"Unsupported architecture: {arch}")
39+
40+
flops_func = flops_by_device[device_type][arch]
41+
42+
return flops_func(width, num_sms=num_sms, clock_rate=clock_rate)
43+
44+
45+
def max_bps(bus_width, memory_clock_rate):
46+
"""
47+
Calculate the maximum bytes per second for a given bus width and memory clock rate.
48+
49+
Args:
50+
bus_width (int): The bus width in bits.
51+
memory_clock_rate (float): The memory clock rate in GHz.
52+
53+
Returns:
54+
float: The maximum bytes per second.
55+
"""
56+
return 2 * bus_width * memory_clock_rate * 1e3 / 8

third_party/proton/proton/viewer.py

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
from collections import namedtuple
33
import json
44
import pandas as pd
5+
56
try:
67
import hatchet as ht
78
from hatchet.query import NegationQuery
89
except ImportError:
910
raise ImportError("Failed to import hatchet. `pip install llnl-hatchet` to get the correct version.")
1011
import numpy as np
1112
from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME, TritonHook
13+
from triton.profiler import specs
1214

1315

1416
def match_available_metrics(metrics, inclusive_metrics, exclusive_metrics):
@@ -83,25 +85,7 @@ def get_min_time_flops(df, device_info):
8385
device_frames = df[idx]
8486
if f"flops{width}" not in device_frames.columns:
8587
continue
86-
max_flops = 0
87-
if device_type == "CUDA":
88-
if arch == "80":
89-
max_flops = 624e12 / (width / 8)
90-
elif arch == "89":
91-
# TODO(Keren): Implement fp16 acc-> 660.6 fp8
92-
max_flops = (330.3 * 1e12) / (width / 8)
93-
elif arch == "90":
94-
# 114 sms and 1755mhz is the base number of sms and clock rate of H100 pcie
95-
max_flops = ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) / (width / 8)
96-
elif arch == "100":
97-
max_flops = (num_sms * 16384 * (clock_rate / 1e3) * 1e6) / (width / 8)
98-
elif device_type == "HIP":
99-
if arch == "gfx90a":
100-
max_flops = 383e12 / (width / 8)
101-
elif arch == "gfx941" or arch == "gfx942":
102-
max_flops = 2614.9e12 / (width / 8)
103-
else:
104-
raise ValueError(f"Unsupported device type: {device_type}")
88+
max_flops = specs.max_flops(device_type, arch, width, num_sms, clock_rate)
10589
min_time_flops.loc[idx, "min_time"] += device_frames[f"flops{width}"].fillna(0) / max_flops
10690
return min_time_flops
10791

@@ -114,7 +98,7 @@ def get_min_time_bytes(df, device_info):
11498
device_frames = df[idx]
11599
memory_clock_rate = device_info[device_type][device_index]["memory_clock_rate"] # in khz
116100
bus_width = device_info[device_type][device_index]["bus_width"] # in bits
117-
peak_bandwidth = 2 * bus_width * memory_clock_rate * 1e3 / 8
101+
peak_bandwidth = specs.max_bps(bus_width, memory_clock_rate)
118102
min_time_bytes.loc[idx, "min_time"] += device_frames["bytes"] / peak_bandwidth
119103
return min_time_bytes
120104

@@ -150,7 +134,7 @@ def derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_inf
150134

151135
def get_time_seconds(df, metric, factor_dict):
152136
time_metric_name = match_available_metrics(metric, inclusive_metrics, exclusive_metrics)[0]
153-
time_unit = (factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0])
137+
time_unit = factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0]
154138
return df[time_metric_name] * factor_dict.factor[time_unit]
155139

156140
for metric in metrics:
@@ -171,13 +155,13 @@ def get_time_seconds(df, metric, factor_dict):
171155
(get_time_seconds(gf.dataframe, "time", time_factor_dict)) /
172156
metric_factor_dict[metric])
173157
derived_metrics.append(f"{metric} (inc)")
174-
elif metric in time_factor_dict.factor or metric in cpu_time_factor_dict.factor or \
175-
metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor: # inclusive
158+
elif (metric in time_factor_dict.factor or metric in cpu_time_factor_dict.factor
159+
or metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor): # inclusive
176160
is_cpu = metric in cpu_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor
177161
is_avg = metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor
178162

179-
factor_dict = (avg_cpu_time_factor_dict if is_avg else cpu_time_factor_dict) if is_cpu \
180-
else (avg_time_factor_dict if is_avg else time_factor_dict)
163+
factor_dict = ((avg_cpu_time_factor_dict if is_avg else cpu_time_factor_dict) if is_cpu else
164+
(avg_time_factor_dict if is_avg else time_factor_dict))
181165
metric_name = "cpu_time" if is_cpu else "time"
182166
metric_time_unit = factor_dict.name + "/" + metric.split("/")[1]
183167

@@ -265,21 +249,26 @@ def print_tree(gf, metrics, depth=100, format=None, print_sorted=False):
265249
print("Sorted kernels by metric " + metrics[0])
266250
sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False)
267251
for row in range(1, len(sorted_df)):
268-
kernel_name = sorted_df.iloc[row]['name'][:100] + "..." if len(
269-
sorted_df.iloc[row]['name']) > 100 else sorted_df.iloc[row]['name']
252+
kernel_name = (sorted_df.iloc[row]["name"][:100] +
253+
"..." if len(sorted_df.iloc[row]["name"]) > 100 else sorted_df.iloc[row]["name"])
270254
print("{:105} {:.4}".format(kernel_name, sorted_df.iloc[row][metrics[0]]))
271255
emit_warnings(gf, metrics)
272256

273257

274-
def parse(metrics, filename, include=None, exclude=None, threshold=None):
258+
def read(filename):
275259
with open(filename, "r") as f:
276260
gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f)
277261
assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file"
278262
gf.update_inclusive_columns()
279-
metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info)
280-
# TODO: generalize to support multiple metrics, not just the first one
281-
gf = filter_frames(gf, include, exclude, threshold, metrics[0])
282-
return gf, metrics
263+
return gf, inclusive_metrics, exclusive_metrics, device_info
264+
265+
266+
def parse(metrics, filename, include=None, exclude=None, threshold=None):
267+
gf, inclusive_metrics, exclusive_metrics, device_info = read(filename)
268+
metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info)
269+
# TODO: generalize to support multiple metrics, not just the first one
270+
gf = filter_frames(gf, include, exclude, threshold, metrics[0])
271+
return gf, metrics
283272

284273

285274
def show_metrics(file_name):
@@ -368,23 +357,32 @@ def main():
368357
help="The depth of the tree to display",
369358
)
370359
argparser.add_argument(
371-
"-f", "--format", type=str, choices=["full", "file_function_line", "function_line", "file_function"],
372-
default="full", help="""Formatting the frame name.
360+
"-f",
361+
"--format",
362+
type=str,
363+
choices=["full", "file_function_line", "function_line", "file_function"],
364+
default="full",
365+
help="""Formatting the frame name.
373366
- full: include the path, file name, function name and line number.
374367
- file_function_line: include the file name, function name and line number.
375368
- function_line: include the function name and line number.
376369
- file_function: include the file name and function name.
377-
""")
370+
""",
371+
)
378372
argparser.add_argument(
379373
"--print-sorted",
380-
action='store_true',
374+
action="store_true",
381375
default=False,
382376
help="Sort output by metric value instead of chronologically",
383377
)
384378
argparser.add_argument(
385-
"--diff-profile", "-diff", type=str, default=None,
379+
"--diff-profile",
380+
"-diff",
381+
type=str,
382+
default=None,
386383
help="Compare two profiles. When used as 'proton-viewer -m time -diff file1.log file2.log', "
387-
"computes the difference: file2['time'] - file1['time']")
384+
"computes the difference: file2['time'] - file1['time']",
385+
)
388386

389387
args, target_args = argparser.parse_known_args()
390388
assert len(target_args) == 1, "Must specify a file to read"

third_party/proton/test/examples/hip.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
"num_sms": 104
5454
},
5555
"1": {
56-
"arch": "gfx941",
56+
"arch": "gfx942",
5757
"bus_width": 8192,
5858
"clock_rate": 5200000,
5959
"memory_clock_rate": 2525000,

0 commit comments

Comments
 (0)