84 changes: 28 additions & 56 deletions python/triton_kernels/bench/bench_mlp.py
@@ -1,8 +1,8 @@
from pathlib import Path
from copy import deepcopy
import matplotlib.pyplot as plt
import json
import triton.profiler as proton
from triton.profiler import viewer
import torch
import triton_kernels
import triton_kernels.swiglu
@@ -21,35 +21,6 @@
cublas = None


def _query_gpu_specs():
import subprocess
if is_hip():
cmd = ["rocm-smi", "--showproductname", "-d=0", "--csv"]
output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
model = output.splitlines()[1].split(",")[2]
if model in ["0x74a9", "0x74a1"]:
name = "AMD Instinct MI300X"
elif model == "0x74a5":
name = "AMD Instinct MI325X"
else:
name = "AMD"
else:
cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader", "-i=0"]
output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
name = output.splitlines()[0]

gpu_specs = {
"NVIDIA H100 80GB HBM3": {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35},
"NVIDIA GB200": {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0},
"AMD Instinct MI300X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 5.3},
"AMD Instinct MI325X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 6.0},
}
return gpu_specs.get(name)


SPECS = _query_gpu_specs()


def quantize(w, dtype, dev, **opt):
if dtype == "bf16":
wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
@@ -80,6 +51,8 @@ class PerfData:
flops: float
bytes: float
bitwidth: int
device_type: str
device_info: dict

@property
def tflops(self):
@@ -95,15 +68,20 @@ def opint(self):
assert self.bytes > 0
return self.flops / self.bytes

@property
def max_tbps(self):
return proton.specs.max_bps(self.device_info["bus_width"], self.device_info["memory_clock_rate"]) * 1e-12

@property
def max_tflops(self):
return proton.specs.max_flops(self.device_type, self.device_info["arch"], self.bitwidth,
self.device_info["num_sms"], self.device_info["clock_rate"]) * 1e-12

@property
def util(self) -> float:
if SPECS is None:
return 0.0
assert self.bitwidth in (8, 16)

peak_flops = SPECS["MAX_TFLOPS8"] if self.bitwidth == 8 else SPECS["MAX_TFLOPS16"]
min_t_flop = self.flops / peak_flops * 1e-3 # ns → µs
min_t_bw = self.bytes / SPECS["MAX_TBPS"] * 1e-3
min_t_flop = self.flops / self.max_tflops * 1e-3
min_t_bw = self.bytes / self.max_tbps * 1e-3
return max(min_t_flop, min_t_bw) / self.time
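For intuition, the new util property can be checked by hand. The peak numbers below are illustrative (roughly H100-PCIe-class), not values taken from this PR:

```python
# Hypothetical values: 1e12 FLOP at 1513 TFLOP/s and 1e9 bytes at 2.0 TB/s.
flops, max_tflops = 1e12, 1513.0   # max_tflops in TFLOP/s
nbytes, max_tbps = 1e9, 2.0        # max_tbps in TB/s
time_ns = 800_000                  # measured kernel time in ns
min_t_flop = flops / max_tflops * 1e-3  # ideal compute time in ns -> ~661,000
min_t_bw = nbytes / max_tbps * 1e-3     # ideal memory time in ns  -> ~500,000
util = max(min_t_flop, min_t_bw) / time_ns  # ~0.83
```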


@@ -171,21 +149,17 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
proton.finalize()

# -- analyze --
with open(f"{fpath}") as fd:
data = json.load(fd)
# TODO: this will be broken if kernels use scopes themselves
# compute useful (a.k.a. matmul) bytes and flops
matmuls = [
x for x in data[0]["children"] if "_matmul" in x["frame"]["name"] and "metadata" not in x["frame"]["name"]
]
bytes = sum([x["metrics"]["bytes"] for x in matmuls])
flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
flops = sum([flops[w] for w in [8, 16]])
# compute total time (incl. "not useful" work)
# TODO: proton should really be recording that in the json instead of
# relying on the user to aggregate
time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
return PerfData(time, flops, bytes, x_dtype.itemsize * 8)
gf, _, _, info = viewer.read(fpath)
# Now the dataframe only contains leaf nodes (i.e., kernels) that perform matmuls
matmuls = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*matmul.*' AND c IS LEAF").dataframe
bytes = matmuls["bytes"].sum()
flops = sum(matmuls[[c for c in ["flops8", "flops16"] if c in matmuls.columns]].sum())
time = matmuls["time (ns)"].sum()
device_type = matmuls["device_type"].iloc[0]
device_id = matmuls["device_id"].iloc[0]
device_info = info[device_type][device_id]
return PerfData(time=time, flops=flops, bytes=bytes, bitwidth=x.dtype.itemsize * 8, device_type=device_type,
device_info=device_info)
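Read in isolation, the new analysis path looks like this; the profile filename is hypothetical, while viewer.read and the hatchet MATCH query are exactly the calls used above:

```python
from triton.profiler import viewer

gf, inclusive_metrics, exclusive_metrics, info = viewer.read("mlp.hatchet")
# Keep only leaf frames (i.e., kernels) whose name contains "matmul"
matmuls = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*matmul.*' AND c IS LEAF").dataframe
print(matmuls[["bytes", "time (ns)"]].sum())
```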


def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
@@ -204,6 +178,8 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
print(f"Batch: {batch}; Util: {perfs[-1].util}; TFLOPS: {perfs[-1].tflops}; TBPS: {perfs[-1].tbps}")
print("===============================================================")
# machine limits
max_tbps = perfs[0].max_tbps
max_tflops = perfs[0].max_tflops
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
ax.set_xlabel("batch size (toks/expt)")
ax.set_ylabel("performance [TFLOP/s]")
@@ -214,10 +190,8 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
xmin, xmax = min(xs), max(xs)
dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
ax.set_xlim(xmin - dx, xmax + dx)
ax.set_ylim(100, SPECS["MAX_TFLOPS8"] + 500)
ax.set_ylim(100, max_tflops + 500)
# plot roofline
max_tbps = SPECS["MAX_TBPS"]
max_tflops = SPECS["MAX_TFLOPS8"]
opints = [p.opint for p in perfs]
knee = bisect_left(opints, max_tflops / max_tbps) - 1
x_bw, x_comp = xs[:knee], xs[knee:]
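The knee is the point where arithmetic intensity (FLOP/byte) reaches max_tflops / max_tbps: batches left of it are bandwidth-bound, batches right of it compute-bound. A minimal sketch of the curve being plotted (the helper name is illustrative, not part of the PR):

```python
def attainable_tflops(opint, max_tflops, max_tbps):
    # Roofline model: performance is capped by memory bandwidth (opint * max_tbps)
    # until the knee at opint == max_tflops / max_tbps, then by peak compute.
    return min(opint * max_tbps, max_tflops)
```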
@@ -237,8 +211,6 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_

if __name__ == "__main__":
has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
if SPECS is None:
print("Current GPU has no specs provided, utilization is N/A")
batch_ranges_dense = [(1024, 32768, 1024)]
batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
dense_dtypes = ["fp8", "fp8"]
2 changes: 1 addition & 1 deletion python/triton_kernels/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "triton_kernels"
version = "1.0.0"
dependencies = ["torch", "numpy", "pytest"]
dependencies = ["torch", "numpy", "pytest", "llnl-hatchet"]
Comment (Contributor): This dependency is causing problems internally because it depends on numpy 1. It appears that hatchet is not actually used inside triton_kernels here, so I think it should be possible to remove this dependency.
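If the hard dependency is dropped here, one option (a sketch only, not part of this PR) is to rely on the guarded import that viewer.py already performs:

```python
try:
    import hatchet as ht
except ImportError:
    raise ImportError("Failed to import hatchet. `pip install llnl-hatchet` to get the correct version.")
```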


[build-system]
requires = ["setuptools>=64.0"]
Expand Down
2 changes: 1 addition & 1 deletion third_party/proton/proton/__init__.py
@@ -9,4 +9,4 @@
profile,
DEFAULT_PROFILE_NAME,
)
from . import context
from . import context, specs
56 changes: 56 additions & 0 deletions third_party/proton/proton/specs.py
@@ -0,0 +1,56 @@
flops_by_device = {
"CUDA": {
"80":
lambda width, **kwargs: 624e12 / (width / 8),
"89":
lambda width, **kwargs: (330.3 * 1e12) / (width / 8), # TODO(Keren): Implement fp16 acc-> 660.6 fp8
"90":
lambda width, num_sms, clock_rate, **kwargs: ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) /
(width / 8),
"100":
lambda width, num_sms, clock_rate, **kwargs: (num_sms * 16384 * (clock_rate / 1e3) * 1e6) / (width / 8),
},
"HIP": {
"gfx90a": lambda width, **kwargs: 383e12 / (width / 8),
"gfx942": lambda width, **kwargs: 2614.9e12 / (width / 8),
Comment (Contributor Author): @antiagainst shall we drop a 0 or None value here for gfx950?

Comment (Collaborator): Yeah. The spec is unavailable right now. For mi300x and mi325x you can see the spec in https://github.com/triton-lang/triton/pull/6513/files#diff-5e6a8d3fc5ad9de85fc09ead926355cc19497d3056c368d463bf4c626ce68540. Can drop gfx941 here given that it's deprecated.

Comment (Contributor Author): So only "gfx90a" and "gfx942" at this moment?

Comment (Contributor Author): Also cc @ptillet for any additional comments.

Comment (Collaborator): Yup -- gfx90a for mi210/mi250, gfx942 for mi300/mi325. gfx940 and gfx941 were deprecated; see llvm/llvm-project#126763.
},
}


def max_flops(device_type, arch, width, num_sms, clock_rate):
"""
Calculate the maximum FLOPS for a given device type and width.

Args:
device_type (str): The type of device (e.g., "CUDA", "HIP").
arch (str): The architecture of the device (e.g., "80", "90").
width (int): The width in bits.
num_sms (int): The number of streaming multiprocessors.
clock_rate (float): The clock rate in kHz.

Returns:
float: The maximum FLOPS for the given device type and width.
"""
if device_type not in flops_by_device:
raise ValueError(f"Unsupported device type: {device_type}")

if arch not in flops_by_device[device_type]:
raise ValueError(f"Unsupported architecture: {arch}")

flops_func = flops_by_device[device_type][arch]

return flops_func(width, num_sms=num_sms, clock_rate=clock_rate)


def max_bps(bus_width, memory_clock_rate):
"""
Calculate the maximum bytes per second for a given bus width and memory clock rate.

Args:
bus_width (int): The bus width in bits.
memory_clock_rate (float): The memory clock rate in kHz.

Returns:
float: The maximum bytes per second.
"""
# DDR memory transfers twice per clock; kHz -> Hz via 1e3; bits -> bytes via / 8
return 2 * bus_width * memory_clock_rate * 1e3 / 8
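As a sanity check of the two helpers, the parameters below are approximate H100 PCIe values chosen for illustration; note that both clock rates are expected in kHz:

```python
from triton.profiler import specs

# 114 SMs at a 1755 MHz base clock -> ~1513e12 FLOPS at 8-bit width
peak_flops = specs.max_flops("CUDA", "90", width=8, num_sms=114, clock_rate=1755e3)
# 5120-bit bus at 1593 MHz (double data rate) -> ~2.0e12 bytes/s
peak_bps = specs.max_bps(bus_width=5120, memory_clock_rate=1593e3)
```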
74 changes: 36 additions & 38 deletions third_party/proton/proton/viewer.py
@@ -2,13 +2,15 @@
from collections import namedtuple
import json
import pandas as pd

try:
import hatchet as ht
from hatchet.query import NegationQuery
except ImportError:
raise ImportError("Failed to import hatchet. `pip install llnl-hatchet` to get the correct version.")
import numpy as np
from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME, TritonHook
from triton.profiler import specs


def match_available_metrics(metrics, inclusive_metrics, exclusive_metrics):
@@ -83,25 +85,7 @@ def get_min_time_flops(df, device_info):
device_frames = df[idx]
if f"flops{width}" not in device_frames.columns:
continue
max_flops = 0
if device_type == "CUDA":
if arch == "80":
max_flops = 624e12 / (width / 8)
elif arch == "89":
# TODO(Keren): Implement fp16 acc-> 660.6 fp8
max_flops = (330.3 * 1e12) / (width / 8)
elif arch == "90":
# 114 sms and 1755mhz is the base number of sms and clock rate of H100 pcie
max_flops = ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) / (width / 8)
elif arch == "100":
max_flops = (num_sms * 16384 * (clock_rate / 1e3) * 1e6) / (width / 8)
elif device_type == "HIP":
if arch == "gfx90a":
max_flops = 383e12 / (width / 8)
elif arch == "gfx941" or arch == "gfx942":
max_flops = 2614.9e12 / (width / 8)
else:
raise ValueError(f"Unsupported device type: {device_type}")
max_flops = specs.max_flops(device_type, arch, width, num_sms, clock_rate)
min_time_flops.loc[idx, "min_time"] += device_frames[f"flops{width}"].fillna(0) / max_flops
return min_time_flops

@@ -114,7 +98,7 @@ def get_min_time_bytes(df, device_info):
device_frames = df[idx]
memory_clock_rate = device_info[device_type][device_index]["memory_clock_rate"] # in khz
bus_width = device_info[device_type][device_index]["bus_width"] # in bits
peak_bandwidth = 2 * bus_width * memory_clock_rate * 1e3 / 8
peak_bandwidth = specs.max_bps(bus_width, memory_clock_rate)
min_time_bytes.loc[idx, "min_time"] += device_frames["bytes"] / peak_bandwidth
return min_time_bytes

@@ -150,7 +134,7 @@ def derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_inf

def get_time_seconds(df, metric, factor_dict):
time_metric_name = match_available_metrics(metric, inclusive_metrics, exclusive_metrics)[0]
time_unit = (factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0])
time_unit = factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0]
return df[time_metric_name] * factor_dict.factor[time_unit]

for metric in metrics:
@@ -171,13 +155,13 @@ def get_time_seconds(df, metric, factor_dict):
(get_time_seconds(gf.dataframe, "time", time_factor_dict)) /
metric_factor_dict[metric])
derived_metrics.append(f"{metric} (inc)")
elif metric in time_factor_dict.factor or metric in cpu_time_factor_dict.factor or \
metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor: # inclusive
elif (metric in time_factor_dict.factor or metric in cpu_time_factor_dict.factor
or metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor): # inclusive
is_cpu = metric in cpu_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor
is_avg = metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor

factor_dict = (avg_cpu_time_factor_dict if is_avg else cpu_time_factor_dict) if is_cpu \
else (avg_time_factor_dict if is_avg else time_factor_dict)
factor_dict = ((avg_cpu_time_factor_dict if is_avg else cpu_time_factor_dict) if is_cpu else
(avg_time_factor_dict if is_avg else time_factor_dict))
metric_name = "cpu_time" if is_cpu else "time"
metric_time_unit = factor_dict.name + "/" + metric.split("/")[1]

@@ -265,21 +249,26 @@ def print_tree(gf, metrics, depth=100, format=None, print_sorted=False):
print("Sorted kernels by metric " + metrics[0])
sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False)
for row in range(1, len(sorted_df)):
kernel_name = sorted_df.iloc[row]['name'][:100] + "..." if len(
sorted_df.iloc[row]['name']) > 100 else sorted_df.iloc[row]['name']
kernel_name = (sorted_df.iloc[row]["name"][:100] +
"..." if len(sorted_df.iloc[row]["name"]) > 100 else sorted_df.iloc[row]["name"])
print("{:105} {:.4}".format(kernel_name, sorted_df.iloc[row][metrics[0]]))
emit_warnings(gf, metrics)


def parse(metrics, filename, include=None, exclude=None, threshold=None):
def read(filename):
with open(filename, "r") as f:
gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f)
assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file"
gf.update_inclusive_columns()
metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info)
# TODO: generalize to support multiple metrics, not just the first one
gf = filter_frames(gf, include, exclude, threshold, metrics[0])
return gf, metrics
return gf, inclusive_metrics, exclusive_metrics, device_info


def parse(metrics, filename, include=None, exclude=None, threshold=None):
gf, inclusive_metrics, exclusive_metrics, device_info = read(filename)
metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info)
# TODO: generalize to support multiple metrics, not just the first one
gf = filter_frames(gf, include, exclude, threshold, metrics[0])
return gf, metrics
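A hypothetical usage of the refactored API, assuming a profile file and the "time/ms" metric: read() now exposes the raw graph frame, while parse() keeps the old derive-and-filter behavior.

```python
from triton.profiler.viewer import parse, read

# Full pipeline: derive metrics and filter frames (as before).
gf, metrics = parse(["time/ms"], "mlp.hatchet")
print(gf.dataframe[metrics].head())

# Raw access: no derived metrics, no filtering.
gf, inclusive_metrics, exclusive_metrics, device_info = read("mlp.hatchet")
```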


def show_metrics(file_name):
@@ -368,23 +357,32 @@ def main():
help="The depth of the tree to display",
)
argparser.add_argument(
"-f", "--format", type=str, choices=["full", "file_function_line", "function_line", "file_function"],
default="full", help="""Formatting the frame name.
"-f",
"--format",
type=str,
choices=["full", "file_function_line", "function_line", "file_function"],
default="full",
help="""Formatting the frame name.
- full: include the path, file name, function name and line number.
- file_function_line: include the file name, function name and line number.
- function_line: include the function name and line number.
- file_function: include the file name and function name.
""")
""",
)
argparser.add_argument(
"--print-sorted",
action='store_true',
action="store_true",
default=False,
help="Sort output by metric value instead of chronologically",
)
argparser.add_argument(
"--diff-profile", "-diff", type=str, default=None,
"--diff-profile",
"-diff",
type=str,
default=None,
help="Compare two profiles. When used as 'proton-viewer -m time -diff file1.log file2.log', "
"computes the difference: file2['time'] - file1['time']")
"computes the difference: file2['time'] - file1['time']",
)

args, target_args = argparser.parse_known_args()
assert len(target_args) == 1, "Must specify a file to read"
2 changes: 1 addition & 1 deletion third_party/proton/test/examples/hip.json
@@ -53,7 +53,7 @@
"num_sms": 104
},
"1": {
"arch": "gfx941",
"arch": "gfx942",
"bus_width": 8192,
"clock_rate": 5200000,
"memory_clock_rate": 2525000,