
Low matmul performance #12

@AKKamath

Hi! Thanks for open-sourcing these batch-invariant implementations!

I'm running them on an H200 and seeing much lower performance than the blog post reports.

Below is the plotted performance of the batch-invariant matmul vs. PyTorch's matmul.
[Plot: achieved TFLOPS vs. batch size for the batch-invariant matmul and torch.mm]

I seem to get up to 750 TFLOPS from PyTorch, which is close to the peak device FLOPS. The batch-invariant matmul barely manages to reach 300 TFLOPS, whereas according to the blog post it should reach up to 600. Below is the exact code I'm using to benchmark. Is there something I've done wrong, or that should be changed?

import time

import torch
from batch_invariant_ops import matmul_persistent

def bench_perf(matmul_func, B, D=4096, iterations=50):
    a = torch.randn(B, D, device='cuda', dtype=torch.float16)
    b = torch.randn(D, D, device='cuda', dtype=torch.float16)

    # Warm-up so one-time costs (e.g. kernel compilation) don't skew the timing
    for _ in range(5):
        _ = matmul_func(a, b)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iterations):
        _ = matmul_func(a, b)
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time = (end - start) / iterations
    # A (B, D) @ (D, D) matmul performs 2 * B * D * D FLOPs
    tflops = 2 * B * D * D / (avg_time * 1e12)
    print(f"Avg Time: {avg_time*1000:.2f} ms, TFLOPS: {tflops:.2f}")
    return tflops

BATCHES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 1024, 2048, 4096, 8192, 16384]

torch_tflops = {}
bi_tflops = {}

# Benchmark both implementations across batch sizes
for batch_size in BATCHES:
    torch_tflops[batch_size] = bench_perf(torch.mm, B=batch_size)
    bi_tflops[batch_size] = bench_perf(matmul_persistent, B=batch_size)
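
In case the hand-rolled timing loop is suspect, the same measurement can also be done with torch.utils.benchmark.Timer, which performs its own warm-up and CUDA synchronization. A minimal sketch (bench_perf_timer is a hypothetical drop-in for bench_perf above, not code I've run):

import torch
from torch.utils import benchmark

def bench_perf_timer(matmul_func, B, D=4096, iterations=50):
    a = torch.randn(B, D, device='cuda', dtype=torch.float16)
    b = torch.randn(D, D, device='cuda', dtype=torch.float16)
    # Timer warms up the statement and synchronizes CUDA around the timed region
    t = benchmark.Timer(
        stmt='matmul_func(a, b)',
        globals={'matmul_func': matmul_func, 'a': a, 'b': b},
    )
    avg_time = t.timeit(iterations).mean  # mean seconds per call
    tflops = 2 * B * D * D / (avg_time * 1e12)
    print(f"Avg Time: {avg_time*1000:.2f} ms, TFLOPS: {tflops:.2f}")
    return tflops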
