Hi! Thanks for open-sourcing these batch-invariant implementations!
I'm running them on an H200 and seeing noticeably lower performance than the numbers reported in the blog post.
Below is the plotted performance of the batch-invariant matmul vs PyTorch's matmul.

I seem to get up to 750 TFLOPS from PyTorch, which is close to the device's peak FLOPS. The batch-invariant matmul barely manages to reach 300 TFLOPS, whereas it should reach roughly 600 according to the blog post. Below is the exact code I'm using to benchmark. Is there something I've done wrong, or something that should be changed?
```python
import time

import torch
from batch_invariant_ops import matmul_persistent


def bench_perf(matmul_func, B, D=4096, iterations=50):
    a = torch.randn(B, D, device='cuda', dtype=torch.float16)
    b = torch.randn(D, D, device='cuda', dtype=torch.float16)

    # Warm-up
    for _ in range(5):
        _ = matmul_func(a, b)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iterations):
        _ = matmul_func(a, b)
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time = (end - start) / iterations
    tflops = 2 * B * D * D / (avg_time * 1e12)
    print(f"Avg Time: {avg_time*1000:.2f} ms, TFLOPS: {tflops:.2f}")
    return tflops  # needed so the dicts below actually collect the results


BATCHES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 1024, 2048, 4096, 8192, 16384]
torch_tflops = {}
bi_tflops = {}

# Benchmark performance
for batch_size in BATCHES:
    torch_tflops[batch_size] = bench_perf(torch.mm, B=batch_size)
    bi_tflops[batch_size] = bench_perf(matmul_persistent, B=batch_size)
```
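
In case host-side timing overhead matters at the small batch sizes, the same benchmark can also be written with CUDA events so the measurement happens on the GPU side. This is only a rough sketch along the lines of the code above (the name `bench_perf_events` and the event-based timing are my own, not something from the repo), but it should give equivalent numbers for the large-batch cases:

```python
import torch
from batch_invariant_ops import matmul_persistent


def bench_perf_events(matmul_func, B, D=4096, iterations=50):
    """Same benchmark as above, but timed with CUDA events instead of perf_counter."""
    a = torch.randn(B, D, device='cuda', dtype=torch.float16)
    b = torch.randn(D, D, device='cuda', dtype=torch.float16)

    # Warm-up so any compilation/autotuning doesn't count against the timing
    for _ in range(5):
        _ = matmul_func(a, b)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iterations):
        _ = matmul_func(a, b)
    end.record()
    torch.cuda.synchronize()

    # elapsed_time returns milliseconds
    avg_time = start.elapsed_time(end) / 1000.0 / iterations
    tflops = 2 * B * D * D / (avg_time * 1e12)
    print(f"Avg Time: {avg_time*1000:.2f} ms, TFLOPS: {tflops:.2f}")
    return tflops
```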