Skip to content

Commit a32009b

Browse files
Alex4210987 and Your Name authored
[CI] Add norm and layout_plot (#534)
* [CI]Add norm and layout_plot * fix lint * Remove obsolete test files for RMS normalization and plot layout, streamlining the testing suite. * Add make_mma_load_base_layout function to create MMA result layouts - Introduced a new function `make_mma_load_base_layout` for generating layout functions for storing MMA results in fragment buffers. - Added detailed docstring explaining parameters, return values, and potential exceptions. - Implemented logic for handling different data types and matrix configurations, including assertions for input validation. - Defined internal functions for mapping fragment indices to threads and local indices, enhancing the layout functionality. * Enhance MMA load test with additional imports and functionality - Added imports for `tilelang.language`, `Literal`, `Callable`, `DataType`, `IndexMap`, and `get_mma_micro_size` to support extended functionality. - Improved the `make_mma_load_base_layout` function by ensuring it can handle various data types and configurations. - Updated the test function `test_mma_load_base_layout` to validate the layout for float16 matrix A. * Fix formatting in test_fragment_mma_load_a.py by adding a blank line for improved readability. * Add RMS normalization functions to test_rms_norm.py - Introduced `rms_norm` and `rms_norm_splitk` functions for RMS normalization, enhancing the testing capabilities. - Implemented kernel functions with shared memory allocation and parallel processing for improved performance. - Updated the test function to validate the new RMS normalization implementations. * Add reference program for RMS normalization in test_rms_norm.py - Introduced `ref_program` function to provide a reference implementation for RMS normalization. - This addition enhances the testing framework by allowing comparisons against a known reference output. 
* Enhance RMS normalization tests with additional imports and formatting - Added import for `tilelang.language` to support extended functionality in `test_rms_norm.py`. - Improved code readability by adding blank lines for better separation of code sections. * Update RMS normalization test parameters and enhance layout plotting - Increased matrix dimensions in `test_rms_norm` to 8192 for improved performance testing. - Removed obsolete test functions in `test_fragment_mma_load_a.py` to streamline the test suite. - Enhanced layout plotting functionality by ensuring proper visualization of base, warp, and block layouts in `test_fragment_mma_load_a.py`. * Refactor RMS normalization test parameters and improve layout plotting readability - Simplified the parameters in `test_rms_norm` by removing `blk_k` for clarity. - Enhanced code readability in `test_fragment_mma_load_a.py` by adjusting the formatting of the `block_layout` definition and removing the unused `warp_cols` variable. * Enhance RMS normalization with split-k implementation and additional profiling - Added a new function `test_rms_norm_splitk` to test the split-k variant of RMS normalization. - Updated the main RMS normalization script to include profiling for the split-k implementation. - Ensured all checks pass with appropriate latency measurements for both reference and tile-lang implementations. * Remove obsolete test file `test_fragment_mma_load_a.py` to streamline the test suite. * Refactor `rms_norm.py` to streamline benchmarking output and remove redundant code. Comment out the `plot_layout` call in `fragment_mma_load_a.py` for clarity. * Refactor `test_rms_norm.py` by removing redundant test function `test_rms_norm_splitk` to streamline the test suite and improve clarity. --------- Co-authored-by: Your Name <you@example.com>
1 parent 7eef7f2 commit a32009b

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

examples/norm/rms_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,4 @@ def ref_program(x):
8181
latency = profiler.do_bench(ref_program, warmup=500)
8282
print("Ref: {:.2f} ms".format(latency))
8383
latency = profiler.do_bench(warmup=500)
84-
print("Tile-lang: {:.2f} ms".format(latency))
84+
print("Tile-lang: {:.2f} ms".format(latency))

examples/norm/test_rms_norm.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) Tile-AI Corporation.
2+
# Licensed under the MIT License.
3+
4+
import torch
5+
import tilelang
6+
import tilelang.language as T
7+
8+
9+
def rms_norm_splitk(M, N, blk_m, blk_k):
    """Build a split-K RMS-normalization TileLang kernel.

    The N (feature) dimension is processed in chunks of ``blk_k`` columns,
    accumulating the per-row sum of squares across chunks before a second
    pass rescales and writes the output.

    Args:
        M: Number of rows of the input/output matrices.
        N: Number of columns (the dimension being normalized over).
        blk_m: Rows handled per kernel block.
        blk_k: Columns loaded per K-step.

    Returns:
        A ``T.prim_func`` taking input ``A`` (M, N) and writing the
        normalized result to ``B`` (M, N).
    """
    dtype = "float"

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        # One block per blk_m rows; 128 threads per block.
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, blk_k), dtype)
            A_local = T.alloc_fragment((blk_m, blk_k), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            num_k_step = T.ceildiv(N, blk_k)
            T.clear(A_local)
            # Pass 1: accumulate elementwise squares over all K-steps.
            for k in range(num_k_step):
                T.copy(A[bx * blk_m, k * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_local[i, j] += A_shared[i, j] * A_shared[i, j]
            # Row-wise sum of the accumulated squares.
            T.reduce_sum(A_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                # NOTE(review): epsilon is added *after* rsqrt here, while
                # ref_program computes rsqrt(mean + 1e-12). The difference is
                # below the 0.01 test tolerance, but confirm which form is
                # intended.
                A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12

            # Pass 2: rescale each chunk and write it out. Chunks are walked
            # in reverse, presumably for a better cache hit rate on the data
            # just touched in pass 1.
            for k in range(num_k_step):
                # reverse, better cache hit rate
                T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_shared[i, j] *= A_powsum[i]
                T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k])

    return main
37+
38+
39+
def rms_norm(M, N, blk_m):
    """Build a single-pass RMS-normalization TileLang kernel.

    Each kernel block loads ``blk_m`` full rows (all N columns) into shared
    memory, computes the per-row root-mean-square, and writes the
    normalized rows back.

    Args:
        M: Number of rows of the input/output matrices.
        N: Number of columns (the dimension being normalized over).
        blk_m: Rows handled per kernel block.

    Returns:
        A ``T.prim_func`` taking input ``A`` (M, N) and writing the
        normalized result to ``B`` (M, N).
    """
    dtype = "float"

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        # One block per blk_m rows; 128 threads per block.
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, N), dtype)
            A_pow_local = T.alloc_fragment((blk_m, N), dtype)
            A_local = T.alloc_fragment((blk_m, N), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            # Stage the block's rows: global -> shared -> fragment.
            T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared)
            T.copy(A_shared, A_local)
            # Elementwise squares, then row-wise sum.
            for i, j in T.Parallel(blk_m, N):
                A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
            T.reduce_sum(A_pow_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                # NOTE(review): epsilon is added *after* rsqrt here, while
                # ref_program computes rsqrt(mean + 1e-12). The difference is
                # below the 0.01 test tolerance, but confirm which form is
                # intended.
                A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
            # Scale each row by its reciprocal RMS and write back.
            for i, j in T.Parallel(blk_m, N):
                A_local[i, j] *= A_powsum[i]
            T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])

    return main
62+
63+
64+
def ref_program(x):
    """Reference RMS normalization: x / sqrt(mean(x**2) + 1e-12).

    Computed with plain PyTorch ops over the last dimension; used as the
    ground truth for the TileLang kernels.
    """
    mean_square = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(mean_square + 1e-12)
66+
67+
68+
def test_rms_norm():
    """Compile the single-pass RMS-norm kernel for an 8192x8192 input and
    check its output against the PyTorch reference within 1% tolerance."""
    rows, cols, rows_per_block = 8192, 8192, 1
    kernel = tilelang.compile(
        rms_norm(rows, cols, rows_per_block),
        out_idx=-1,
        target="cuda",
        execution_backend="cython",
        pass_configs={"tl.disable_tma_lower": True})
    # Profiler drives random inputs through the kernel and the reference.
    profiler = kernel.get_profiler()
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)

0 commit comments

Comments
 (0)