Skip to content

Commit 2481c4b

Browse files
henrylhtsang and pytorchmergebot
authored and committed
[cutlass backend] add teraflops and increase rep for benchmark script (pytorch#154944)
Differential Revision: [D75840023](https://our.internmc.facebook.com/intern/diff/D75840023/) I think I will continue to use do_bench for now. Pull Request resolved: pytorch#154944 Approved by: https://github.com/mlazos
1 parent be2ab96 commit 2481c4b

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

benchmarks/inductor_backends/cutlass.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23

34

45
os.environ["TORCH_LOGS"] = "inductor"
@@ -32,6 +33,7 @@
3233
UNITS = {
3334
"name": "",
3435
"forward_time": " (us)",
36+
"teraflops": " (TFLOPS)",
3537
"compilation_time": " (s)",
3638
}
3739
PERF_OVER_ATEN_STR: str = "perf_over_aten (%)"
@@ -75,7 +77,7 @@
7577

7678

7779
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Benchmark ``func(*args, **kwargs)`` and return its mean runtime in microseconds.

    ``do_bench`` reports milliseconds, so the result is scaled by 1e3.
    The long warmup/rep window (100 / 10000) was chosen to stabilize
    noisy kernel timings — presumably do_bench's triton-style signature; verify.
    """
    call = lambda: func(*args, **kwargs)
    milliseconds = do_bench(call, warmup=100, rep=10000)
    return milliseconds * 1e3
7981

8082

8183
@dataclass(frozen=True, kw_only=True)
@@ -162,6 +164,7 @@ def name(self) -> str:
162164
class ExperimentResults:
163165
name: str
164166
forward_time: float
167+
teraflops: float
165168
compilation_time: float
166169

167170
def asdict(self):
@@ -211,7 +214,10 @@ def run_single_experiment_group(
211214
for config in group_config.experiments:
212215
torch._dynamo.reset()
213216
torch._inductor.utils.clear_inductor_caches()
214-
compiled_op = torch.compile(op, fullgraph=True, options=config.to_options())
217+
compiled_op = torch.compile(
218+
op,
219+
options=config.to_options(),
220+
)
215221

216222
start_time = time.perf_counter()
217223
try:
@@ -227,6 +233,7 @@ def run_single_experiment_group(
227233
ExperimentResults(
228234
name=config.name(),
229235
forward_time=float("inf"),
236+
teraflops=0.0,
230237
compilation_time=float("inf"),
231238
)
232239
)
@@ -238,10 +245,18 @@ def run_single_experiment_group(
238245
*inputs,
239246
)
240247

248+
flops = calculate_flops(
249+
group_config.op_name,
250+
group_config.shape,
251+
group_config.batch_size,
252+
)
253+
teraflops = flops / (forward_time * 1e-6) / 1e12
254+
241255
results.append(
242256
ExperimentResults(
243257
name=config.name(),
244258
forward_time=forward_time,
259+
teraflops=teraflops,
245260
compilation_time=compilation_time,
246261
)
247262
)
@@ -336,6 +351,20 @@ def calculate_table_data(results: list[ExperimentResults]) -> dict:
336351
return table_data
337352

338353

354+
def calculate_flops(op_name: str, shape: tuple[int, int, int], batch_size: int) -> int:
    """
    Calculate the number of floating point operations based on operation type and shape.
    """
    M, N, K = shape
    # One multiply + one add per multiply-accumulate of an M x K @ K x N matmul.
    matmul_flops = 2 * M * N * K

    if op_name == "bmm":
        # Batched matmul: batch_size independent matmuls.
        return batch_size * matmul_flops
    if op_name == "addmm":
        # Matmul plus the elementwise bias addition over the M x N output.
        return matmul_flops + M * N
    # Default (plain mm): just the matmul cost.
    return matmul_flops
366+
367+
339368
def get_printable_results(experiment_groups: list[ExperimentGroup]) -> list[str]:
340369
edge_over_aten = defaultdict(list)
341370
output = []
@@ -390,8 +419,10 @@ def main():
390419
results.append(
391420
ExperimentGroup(config=group_config, results=group_results),
392421
)
393-
log.info(f"\nINTERMEDIATE results: {i}/{len(configs)}") # noqa: G004
394-
log.info(get_printable_results(results))
422+
sys.stderr.write(
423+
f"\nINTERMEDIATE results: {i + 1}/{len(configs)} \n"
424+
+ get_printable_results(results)
425+
)
395426
print("\nFINAL results...")
396427
print(get_printable_results(results))
397428

0 commit comments

Comments
 (0)