2 changes: 2 additions & 0 deletions benchmarks/__init__.py
@@ -19,6 +19,7 @@
MultiHeadAttentionDecodeBenchmark,
)
from .gemm import GemmBenchmark, MatMulBenchmark
from .gemv import GemvBenchmark
from .grouped_gemm import (
GroupedGemmBenchmark,
GroupedGemmNNBenchmark,
@@ -36,6 +37,7 @@
"GroupQueryAttentionFwdBenchmark",
"GroupQueryAttentionBwdBenchmark",
"GemmBenchmark",
"GemvBenchmark",
"MultiHeadAttentionDecodeBenchmark",
"GroupQueryAttentionDecodeBenchmark",
"MultiHeadLatentAttentionDecodeBenchmark",
3 changes: 3 additions & 0 deletions benchmarks/gemv/__init__.py
@@ -0,0 +1,3 @@
from .gemv import GemvBenchmark

__all__ = ["GemvBenchmark"]
39 changes: 39 additions & 0 deletions benchmarks/gemv/gemv.py
@@ -0,0 +1,39 @@
from typing import Tuple

import torch

from benchmarks.benchmark import Benchmark
from top.ops import GemvOp


class GemvBenchmark(Benchmark):

op_type = GemvOp

def __init__(self, n: int, k: int, dtype: torch.dtype):
self.n = n
self.k = k
self.dtype = dtype

@property
def total_flops(self) -> float:
return 2.0 * self.n * self.k

@property
def total_memory(self) -> int:
return (self.k + self.k * self.n + self.n) * self.dtype.itemsize

def gen_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
shape_a = (self.k,)
a = torch.randn(*shape_a, device='cuda', dtype=self.dtype)
shape_b = (self.n, self.k)
b = torch.randn(*shape_b, device='cuda', dtype=self.dtype)
return a, b

    def ref_program(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Matrix-vector product; equivalent to torch.mv(b, a).
        return b @ a

def baseline_profile(self, *inputs, warmup=100, rep=10, device="cuda:0") -> None:
return super().baseline_profile(
self.ref_program, *inputs, backend="torch", warmup=warmup, rep=rep, device=device)
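
The total_flops and total_memory properties above fix GEMV's roofline position: 2·n·k FLOPs against (k + n·k + n) elements of traffic works out to roughly one FLOP per byte in fp16, far below the ridge point of any modern GPU, so the operation is memory-bandwidth-bound at any practical size. A minimal sketch of that arithmetic, assuming only the two formulas above:

import torch

def gemv_arithmetic_intensity(n: int, k: int,
                              dtype: torch.dtype = torch.float16) -> float:
    # Mirrors GemvBenchmark.total_flops / GemvBenchmark.total_memory.
    flops = 2.0 * n * k
    bytes_moved = (k + n * k + n) * dtype.itemsize
    return flops / bytes_moved

# ~1.0 FLOP/byte for fp16 regardless of size -> bandwidth-bound.
print(gemv_arithmetic_intensity(7168, 16384))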
28 changes: 28 additions & 0 deletions tests/ops/test_gemv.py
@@ -0,0 +1,28 @@
import torch
import pytest

from benchmarks import GemvBenchmark
from top.ops import GemvOp


@pytest.mark.parametrize(
"n, k, dtype, tune",
[
(1024, 1024, torch.float16, False),
(7168, 16384, torch.float16, True),
(18432, 7168, torch.float16, True),
],
)
def test_gemv(n: int, k: int, dtype: torch.dtype, tune: bool) -> None:
op = GemvOp(n, k, dtype=dtype, tune=tune)
benchmark = GemvBenchmark(n, k, dtype)

inputs = benchmark.gen_inputs()

benchmark.check(op, *inputs, atol=1e-3, rtol=1e-3)
benchmark.profile(op, *inputs)


if __name__ == "__main__":
# Run tests with pytest
pytest.main([__file__, "-vvs"])
3 changes: 3 additions & 0 deletions top/kernels/gemv/__init__.py
@@ -0,0 +1,3 @@
from .gemv import GemvKernel

__all__ = ["GemvKernel"]
142 changes: 142 additions & 0 deletions top/kernels/gemv/gemv.py
@@ -0,0 +1,142 @@
import itertools
from typing import Callable, Optional

import tilelang
import tilelang.language as T
import torch

from top.kernels.kernel import Kernel
from top.utils import get_sm_version, str2dtype

__all__ = [
'GemvKernel',
]


def _gemv_kernel(n: int, k: int, dtype: str = "float16") -> Callable:
accum_dtype = "float"

@tilelang.jit(out_idx=[-1], compile_flags=["-O3", "-DENABLE_BF16"])
def _gemv_func(
block_n: int,
reduce_threads: int,
) -> Callable:

max_transaction_size_in_bits = 128
tile_k = max_transaction_size_in_bits // (str2dtype[dtype].itemsize * 8)
block_k = reduce_threads * tile_k

@T.prim_func
def _gemv_main(
a: T.Buffer((k,), dtype),
b: T.Buffer((n, k), dtype),
c: T.Buffer((n,), dtype),
):
with T.Kernel(T.ceildiv(n, block_n), threads=(block_n, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
a_local = T.alloc_local((tile_k,), dtype)
b_local = T.alloc_local((tile_k,), dtype)
c_accum = T.alloc_local((1,), accum_dtype)

T.clear(c_accum)
for bk in T.serial(T.ceildiv(k, block_k)):
for _k in T.vectorized(tile_k):
a_local[_k] = a[bk * block_k + tk * tile_k + _k]
b_local[_k] = b[bn * block_n + tn, bk * block_k + tk * tile_k + _k]
for _k in T.serial(tile_k):
c_accum[0] += a_local[_k].astype(accum_dtype) * b_local[_k].astype(
accum_dtype)
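                # Combine the reduce_threads partial sums (indexed by tk)
                # into one value per row via TVM's thread-level allreduce.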
c_reduced = T.alloc_local((1,), accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
c_accum[0],
True,
c_reduced[0],
tk,
dtype="handle",
))

c[bn * block_n + tn] = c_reduced[0]

return _gemv_main

return _gemv_func


@torch.library.custom_op("top::gemv_wrapped_kernel", mutates_args=())
def _gemv_wrapped_kernel(
n: int,
k: int,
dtype: str,
block_n: int,
reduce_threads: int,
a: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
return _gemv_kernel(n, k, dtype)(block_n, reduce_threads)(a, b)


@_gemv_wrapped_kernel.register_fake
def _(n: int, k: int,  # noqa: U100
      dtype: str, block_n: int, reduce_threads: int,  # noqa: U100
      *inputs: torch.Tensor) -> torch.Tensor:  # noqa: U100
return torch.empty((n,), dtype=inputs[0].dtype, device=inputs[0].device)


class GemvKernel(Kernel):
supported_archs: list[int] = [90]

def __init__(self,
n: int,
k: int,
dtype: torch.dtype,
config: Optional[dict] = None,
tune: bool = False) -> None:
super().__init__()
self.n = n
self.k = k
self.dtype = dtype

self.kernel = _gemv_kernel(n, k, self.dtype_str)

self.init_config(config, tune)

@property
def default_config(self) -> dict:
# From tilelang/examples/gemm/example_gemm_autotune.py
sm_version = get_sm_version()

if sm_version in {90}:
return {
"block_n": 32,
"reduce_threads": 8,
}

return {
"block_n": 128,
"reduce_threads": 32,
}

@property
def autotune_configs(self) -> list[dict]:
# From tilelang/examples/gemm/example_gemm_autotune.py
block_n = [64, 128, 256]
reduce_threads = [16, 32]
_configs = list(itertools.product(block_n, reduce_threads))

return [{
'block_n': c[0],
'reduce_threads': c[1],
} for c in _configs]

def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
a = a.flatten().contiguous()
return _gemv_wrapped_kernel(self.n, self.k, self.dtype_str, self.config["block_n"],
self.config["reduce_threads"], a, b)
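
For intuition about the schedule above: each block owns block_n rows; within a row, reduce_threads threads each issue 128-bit vectorized loads of tile_k elements per iteration, accumulate products in fp32, and the per-thread partials are summed across tk by the allreduce. A host-side emulation of that decomposition, offered as a sketch only (it assumes k divides evenly by the tile size, since the loops above do not mask ragged tails):

import torch

def gemv_emulated(a: torch.Tensor, b: torch.Tensor,
                  reduce_threads: int = 8) -> torch.Tensor:
    n, k = b.shape
    tile_k = 128 // (a.element_size() * 8)  # elems per 128-bit load (8 for fp16)
    block_k = reduce_threads * tile_k
    assert k % block_k == 0, "the kernel's tile loops do not mask tails"
    c = torch.zeros(n, dtype=torch.float32, device=a.device)
    for tk in range(reduce_threads):        # one partial sum per reduce thread
        for bk in range(k // block_k):      # serial loop over k tiles
            lo = bk * block_k + tk * tile_k
            c += (b[:, lo:lo + tile_k].float() * a[lo:lo + tile_k].float()).sum(-1)
    return c.to(a.dtype)  # the += over tk plays the role of tvm_thread_allreduce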
2 changes: 2 additions & 0 deletions top/ops/__init__.py
@@ -5,6 +5,7 @@
from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp
from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp
from .gemm import GemmOp
from .gemv import GemvOp
from .gqa import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp
from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp
from .gqa_decode_paged import GroupQueryAttentionDecodePagedWithKVCacheOp
@@ -23,6 +24,7 @@
"GroupQueryAttentionFwdOp",
"GroupQueryAttentionBwdOp",
"GemmOp",
"GemvOp",
"MultiHeadAttentionDecodeWithKVCacheOp",
"MultiHeadAttentionDecodePagedWithKVCacheOp",
"GroupQueryAttentionDecodeWithKVCacheOp",
34 changes: 34 additions & 0 deletions top/ops/gemv.py
@@ -0,0 +1,34 @@
from typing import Dict, Optional

import torch

from top.kernels.gemv import GemvKernel
from top.kernels.kernel import Kernel

from .op import Op

__all__ = ['GemvOp']


class GemvOp(Op):

def __init__(self,
n: int,
k: int,
dtype: torch.dtype = torch.float16,
kernel_map: Optional[Dict[str, Kernel]] = None,
tune: bool = False) -> None:
self.N = n
self.K = k

self.dtype = dtype

self.dispatch_kernel(kernel_map)
self.kernel = self.kernel_map["gemv_kernel"](n, k, self.dtype, tune=tune)

@property
def default_kernel_map(self) -> Dict[str, Kernel]:
return {"gemv_kernel": GemvKernel}

def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return self.kernel(a, b)
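
The kernel_map indirection means GemvOp is not hard-wired to GemvKernel: a caller can substitute any Kernel subclass under the "gemv_kernel" key. A hypothetical sketch (MyGemvKernel is illustrative, not part of this PR):

import torch

from top.kernels.gemv import GemvKernel
from top.ops import GemvOp

class MyGemvKernel(GemvKernel):
    # Hypothetical variant that only overrides the default launch config.
    @property
    def default_config(self) -> dict:
        return {"block_n": 64, "reduce_threads": 16}

op = GemvOp(7168, 16384, dtype=torch.float16,
            kernel_map={"gemv_kernel": MyGemvKernel})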