# General Matrix-Vector Multiplication (GEMV)

<div style="text-align: left;">

Example code can be found at `examples/gemv/example_gemv.py`.

General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using `TileLang`.

## Triton Implementation

When implementing a GEMV kernel, you might start with a high-level approach using a tool like `Triton`.

A simple Triton kernel for GEMV might look like this:
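
(The snippet below is an illustrative sketch rather than the exact `_gemv_naive` kernel from `examples/gemv/example_gemv.py`; it assumes `A` is the length-`K` vector, `B` is a row-major `(K, N)` matrix, and each program instance produces `BLOCK_N` outputs.)

```python
import triton
import triton.language as tl


@triton.jit
def _gemv_naive(
    A, B, C,                      # A: (K,) vector, B: (K, N) matrix, C: (N,) output
    N, K,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Each program instance computes BLOCK_N contiguous elements of C.
    pid = tl.program_id(axis=0)
    offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offs_n < N

    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        mask_k = offs_k < K
        # One BLOCK_K slice of the vector and the matching (BLOCK_K, BLOCK_N) tile of B.
        a = tl.load(A + offs_k, mask=mask_k, other=0.0)
        b = tl.load(
            B + offs_k[:, None] * N + offs_n[None, :],
            mask=mask_k[:, None] & mask_n[None, :],
            other=0.0,
        )
        # Partial dot products for this K slice, reduced over the K axis.
        acc += tl.sum(a.to(tl.float32)[:, None] * b.to(tl.float32), axis=0)

    tl.store(C + offs_n, acc, mask=mask_n)
```

A launch along the lines of `_gemv_naive[(triton.cdiv(N, BLOCK_N),)](a, b, c, N, K, BLOCK_N=128, BLOCK_K=128)` would then cover the whole output vector.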

`Triton` is straightforward to use, as it operates at the block level. However, this approach may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in `TileLang` that exposes more low-level control.

## Naive Implementation in TileLang

If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a `(1, k) * (k, n)` GEMM. Below is a simple example:
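
(The snippet below is an illustrative sketch rather than the exact kernel from `examples/gemv/example_gemv.py`: it stages one `BLOCK_K` slice of the vector and a `(BLOCK_K, BLOCK_N)` tile of the matrix through shared memory at a time and accumulates in registers. The buffer layouts, block sizes, and the 256-thread launch are assumptions, and `N` and `K` are assumed divisible by the block sizes.)

```python
import tilelang.language as T


def naive_gemv(N, K, BLOCK_N=128, BLOCK_K=64, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
        A: T.Buffer((K,), dtype),      # the (1, k) vector
        B: T.Buffer((K, N), dtype),    # the (k, n) matrix
        C: T.Buffer((N,), dtype),      # the (1, n) output
    ):
        # One thread block per BLOCK_N-wide slice of the output.
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=256) as bn:
            A_shared = T.alloc_shared((BLOCK_K,), dtype)
            B_shared = T.alloc_shared((BLOCK_K, BLOCK_N), dtype)
            acc = T.alloc_fragment((BLOCK_N, BLOCK_K), accum_dtype)
            c_frag = T.alloc_fragment((BLOCK_N,), accum_dtype)

            T.clear(acc)
            for bk in T.Pipelined(T.ceildiv(K, BLOCK_K), num_stages=2):
                # Stage the current K tile of the vector and the matrix.
                T.copy(A[bk * BLOCK_K:(bk + 1) * BLOCK_K], A_shared)
                T.copy(B[bk * BLOCK_K:(bk + 1) * BLOCK_K, bn * BLOCK_N:(bn + 1) * BLOCK_N], B_shared)
                # Accumulate element-wise products for this tile.
                for n, k in T.Parallel(BLOCK_N, BLOCK_K):
                    acc[n, k] += A_shared[k].astype(accum_dtype) * B_shared[k, n].astype(accum_dtype)
            # Reduce over the K axis and write the BLOCK_N results back.
            T.reduce_sum(acc, c_frag, dim=1)
            T.copy(c_frag, C[bn * BLOCK_N:(bn + 1) * BLOCK_N])

    return main
```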

In this design, the first 128 threads act as the data producer and the last 128 threads act as the consumer.

At this level, we use only a small fraction of the GPU's compute capability: the kernel takes **~0.17 ms**, compared to torch/cuBLAS's **~0.008 ms**, which is roughly 20x slower.

## More Concurrency

To further increase the concurrency of our kernel, we can exploit finer-grained thread-level parallelism. Instead of assigning each thread to compute a single output element of C, we can introduce parallelism along the K dimension: each thread computes a partial accumulation, and these partial results are then combined. This approach requires primitives like `atomicAdd` in CUDA.
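
(Again a sketch rather than the tutorial's exact `naive_splitk_gemv`: the grid is now 2-D, each block reduces a single `(BLOCK_K, BLOCK_N)` tile, and the per-block partial sums are combined with atomic adds. It assumes `C` is a float32 output that is zero-initialized before launch and that the sizes divide evenly.)

```python
import tilelang.language as T


def naive_splitk_gemv(N, K, BLOCK_N=32, BLOCK_K=128, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
        A: T.Buffer((K,), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((N,), accum_dtype),   # assumed zero-initialized before launch
    ):
        # 2-D grid: bn tiles the N dimension, bk splits the K dimension.
        with T.Kernel(T.ceildiv(N, BLOCK_N), T.ceildiv(K, BLOCK_K), threads=128) as (bn, bk):
            A_shared = T.alloc_shared((BLOCK_K,), dtype)
            B_shared = T.alloc_shared((BLOCK_K, BLOCK_N), dtype)
            prod = T.alloc_fragment((BLOCK_N, BLOCK_K), accum_dtype)
            partial = T.alloc_fragment((BLOCK_N,), accum_dtype)

            T.copy(A[bk * BLOCK_K:(bk + 1) * BLOCK_K], A_shared)
            T.copy(B[bk * BLOCK_K:(bk + 1) * BLOCK_K, bn * BLOCK_N:(bn + 1) * BLOCK_N], B_shared)
            for n, k in T.Parallel(BLOCK_N, BLOCK_K):
                prod[n, k] = A_shared[k].astype(accum_dtype) * B_shared[k, n].astype(accum_dtype)
            # Partial dot products for this block's K slice ...
            T.reduce_sum(prod, partial, dim=1)
            # ... combined across K blocks with atomic adds.
            for n in T.Parallel(BLOCK_N):
                T.atomic_add(C[bn * BLOCK_N + n], partial[n])

    return main
```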

By introducing parallelism along the K dimension, our kernel now achieves **~0.024 ms**, an improvement, but still not on par with torch/cuBLAS.

### Customizing Parallelism in K Dimension

If your K dimension is large, you can further customize how many elements each thread processes by introducing a `reduce_threads` parameter, so that each thread handles multiple elements per iteration; this is what the tutorial's `splitk_gemv` kernel does.
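
To make the decomposition concrete, here is a plain PyTorch emulation (not TileLang) of how `reduce_threads` lanes would split the K dimension, each producing a partial result that is summed at the end; the shapes below are arbitrary:

```python
import torch


def splitk_reference(A: torch.Tensor, B: torch.Tensor, reduce_threads: int = 32) -> torch.Tensor:
    """Emulate the split: lane t accumulates its own K chunk, partials are summed at the end."""
    K, N = B.shape
    assert K % reduce_threads == 0
    chunk = K // reduce_threads
    partials = torch.stack([
        A[t * chunk:(t + 1) * chunk] @ B[t * chunk:(t + 1) * chunk]   # lane t's partial result
        for t in range(reduce_threads)
    ])
    return partials.sum(dim=0)  # the combination step (atomicAdd / all-reduce in the kernel)


A = torch.randn(4096, dtype=torch.float64)
B = torch.randn(4096, 1024, dtype=torch.float64)
torch.testing.assert_close(splitk_reference(A, B), A @ B)
```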

## Vectorized Reads

GEMV is less compute-intensive than GEMM: its arithmetic intensity is low, so memory throughput becomes the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`; the tutorial's `splitk_gemv_vectorized` kernel applies this to the split-K kernel above.

With vectorized reads, the kernel now finishes in **~0.0084 ms**, which is getting close to cuBLAS performance.

## `tvm_thread_allreduce` Instead of `atomicAdd`

[`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) implements an optimized all-reduce across a group of threads, which should outperform our plain shared-memory + `atomicAdd` combination.

With this optimization, the kernel latency drops from **~0.0084 ms** to **~0.0069 ms**, which is faster than torch/cuBLAS!

## Autotune

`BLOCK_N`, `BLOCK_K`, and `reduce_threads` are hyperparameters of our kernel that can be tuned to improve performance. We can use the `tilelang.autotune` feature to automatically search for optimal configurations.
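
`tilelang.autotune` drives this search for you; purely as an illustration of what is being searched, here is a plain-Python sweep in which `build_and_bench` is a hypothetical placeholder standing in for compiling and timing the kernel with one configuration:

```python
import itertools


def build_and_bench(N, K, BLOCK_N, BLOCK_K, reduce_threads):
    """Hypothetical placeholder: compile the GEMV kernel with this tile
    configuration and return its measured latency in milliseconds."""
    raise NotImplementedError


def grid_search(N, K):
    # Enumerate candidate configurations and keep the fastest one.
    best_cfg, best_ms = None, float("inf")
    for BLOCK_N, BLOCK_K, reduce_threads in itertools.product(
        (32, 64, 128), (64, 128, 256), (8, 16, 32)
    ):
        ms = build_and_bench(N, K, BLOCK_N, BLOCK_K, reduce_threads)
        if ms < best_ms:
            best_cfg, best_ms = (BLOCK_N, BLOCK_K, reduce_threads), ms
    return best_cfg, best_ms
```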