Skip to content

Commit a2c7b70

Browse files
Update docs
1 parent 453a501 commit a2c7b70

File tree

26 files changed

+623
-343
lines changed

26 files changed

+623
-343
lines changed

_sources/autoapi/tilelang/env/index.rst.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ Module Contents
250250
.. py:attribute:: TILELANG_CLEAR_CACHE
251251
252252
253+
.. py:attribute:: TILELANG_USE_GEMM_V1
254+
255+
253256
.. py:attribute:: TILELANG_AUTO_TUNING_CPU_UTILITIES
254257
255258
@@ -277,6 +280,15 @@ Module Contents
277280
.. py:method:: is_print_on_compilation_enabled()
278281
279282
283+
.. py:method:: use_gemm_v1()
284+
285+
Return True if GEMM v1 should be used based on env.
286+
287+
Controlled by `TILELANG_USE_GEMM_V1`. Truthy values are one of
288+
{"1", "true", "yes", "on"} (case-insensitive).
289+
290+
291+
280292
.. py:data:: env
281293
282294
.. py:data:: CUDA_HOME

_sources/autoapi/tilelang/intrinsics/mma_layout/index.rst.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Functions
3535
tilelang.intrinsics.mma_layout.ldmatrix_32x16_to_shared_16x32_layout_a
3636
tilelang.intrinsics.mma_layout.ldmatrix_32x16_to_shared_16x32_layout_b
3737
tilelang.intrinsics.mma_layout.mma_store_32x8_to_shared_16x16_layout
38+
tilelang.intrinsics.mma_layout.mma_store_32x2_to_shared_8x8_layout_fp64
3839
tilelang.intrinsics.mma_layout.shared_16x8_to_mma_a_32x4_layout
3940
tilelang.intrinsics.mma_layout.shared_16x8_to_mma_a_32x4_layout_trans
4041
tilelang.intrinsics.mma_layout.shared_16x8_to_mma_b_32x4_layout
@@ -76,6 +77,8 @@ Module Contents
7677
7778
.. py:function:: mma_store_32x8_to_shared_16x16_layout(thread_id, local_id)
7879
80+
.. py:function:: mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id)
81+
7982
.. py:function:: shared_16x8_to_mma_a_32x4_layout(i, j)
8083
8184
.. py:function:: shared_16x8_to_mma_a_32x4_layout_trans(i, j)

_sources/autoapi/tilelang/intrinsics/utils/index.rst.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Functions
1414
tilelang.intrinsics.utils.shared_16x32_to_mma_32x16_layout
1515
tilelang.intrinsics.utils.shared_32x16_to_mma_32x16_layout
1616
tilelang.intrinsics.utils.mma_store_index_map
17+
tilelang.intrinsics.utils.mma_store_index_map_fp64
1718
tilelang.intrinsics.utils.mfma_store_index_map
1819
tilelang.intrinsics.utils.get_mma_micro_size
1920

@@ -31,6 +32,8 @@ Module Contents
3132
3233
.. py:function:: mma_store_index_map(thread_id, local_id)
3334
35+
.. py:function:: mma_store_index_map_fp64(thread_id, local_id)
36+
3437
.. py:function:: mfma_store_index_map(thread_id, local_id)
3538
3639
.. py:function:: get_mma_micro_size(dtype)

_sources/autoapi/tilelang/intrinsics/wgmma_macro_generator/index.rst.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,10 @@ Module Contents
104104

105105

106106

107-
.. py:method:: wgmma(A_buf, B_buf, C_local_buf, clear_accum = False, wg_wait = 0)
107+
.. py:method:: wgmma(A_region, B_region, C_region, clear_accum = False, wg_wait = 0)
108108
109109
110-
.. py:method:: wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum = False, wg_wait = 0)
110+
.. py:method:: wgmma_rs(A_region, B_region, C_region, clear_accum = False, wg_wait = 0)
111111
112112
113113
.. py:method:: make_mma_load_layout(local_buf, matrix = 'A')

_sources/autoapi/tilelang/language/builtin/index.rst.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,15 +383,19 @@ Module Contents
383383
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
384384
WGMMA operations by issuing an empty inline assembly barrier on every register.
385385

386-
:param buffer_or_ptr: Buffer | PrimExpr
387-
Either a buffer representing the accumulator fragment or a pointer expression.
386+
:param buffer_or_ptr: Buffer | BufferLoad | BufferRegion | PrimExpr
387+
A buffer representing the accumulator fragment, a buffer load/region
388+
that identifies a starting element within the fragment, or a pointer expression
389+
(e.g., tvm_access_ptr/address_of/typed Var).
388390
:param offset: int | PrimExpr
389391
Element offset from the start of the accumulator fragment.
390392
:param num_regs: int | PrimExpr | None
391393
Number of 32-bit registers to fence. If None and a Buffer is provided, it will be
392394
derived from the buffer shape and dtype.
393395
:param dtype: str | None
394-
Data type string of the accumulator elements. Required when passing a pointer.
396+
Data type string of the accumulator elements. When passing a buffer or
397+
buffer-derived expression, dtype is inferred. It is required only when
398+
passing a raw pointer expression that cannot be inferred.
395399

396400
:returns: A handle to the warpgroup fence operation.
397401
:rtype: tir.Call

_sources/autoapi/tilelang/language/gemm/index.rst.txt

Lines changed: 2 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -31,72 +31,12 @@ Module Contents
3131

3232
.. py:function:: gemm_v1(A, B, C, transpose_A = False, transpose_B = False, policy = GemmWarpPolicy.Square, clear_accum = False, k_pack = 1, wg_wait = 0, mbar = None)
3333
34-
Perform a General Matrix Multiplication (GEMM) operation.
35-
36-
This function computes C = A @ B where A and B can optionally be transposed.
37-
The operation supports various warp policies and accumulation modes.
38-
39-
:param A: First input matrix
40-
:type A: Union[tir.Buffer, tir.Var]
41-
:param B: Second input matrix
42-
:type B: Union[tir.Buffer, tir.Var]
43-
:param C: Output matrix for results
44-
:type C: Union[tir.Buffer, tir.Var]
45-
:param transpose_A: Whether to transpose matrix A. Defaults to False.
46-
:type transpose_A: bool, optional
47-
:param transpose_B: Whether to transpose matrix B. Defaults to False.
48-
:type transpose_B: bool, optional
49-
:param policy: Warp execution policy. Defaults to GemmWarpPolicy.Square.
50-
:type policy: GemmWarpPolicy, optional
51-
:param clear_accum: Whether to clear accumulator before computation. Defaults to False.
52-
:type clear_accum: bool, optional
53-
:param k_pack: Number of k dimensions packed into a single warp. Defaults to 1.
54-
:type k_pack: int, optional
55-
:param wg_wait: Warp group wait count. Defaults to 0.
56-
On hopper it is equivalent to `wgmma.wait_group.sync.aligned <wg_wait>` if wg_wait is not -1
57-
On sm100, `wg_wait` can only be 0 or -1. `mbarrier_wait(TCGEN5MMA barrier)` will be appended if wg_wait is 0.
58-
:type wg_wait: int, optional
59-
:param mbar: mbarrier for TCGEN5MMA synchronization
60-
:type mbar: tir.Buffer, optional
61-
62-
:returns: A handle to the GEMM operation
63-
:rtype: tir.Call
64-
65-
:raises AssertionError: If the K dimensions of matrices A and B don't match
34+
GEMM v1: use op tl.gemm.
6635

6736

6837
.. py:function:: gemm_v2(A, B, C, transpose_A = False, transpose_B = False, policy = GemmWarpPolicy.Square, clear_accum = False, k_pack = 1, wg_wait = 0, mbar = None)
6938
70-
Perform a General Matrix Multiplication (GEMM) operation.
71-
72-
This function computes C = A @ B where A and B can optionally be transposed.
73-
The operation supports various warp policies and accumulation modes.
74-
75-
:param A: First input matrix
76-
:type A: Union[tir.Buffer, tir.Var]
77-
:param B: Second input matrix
78-
:type B: Union[tir.Buffer, tir.Var]
79-
:param C: Output matrix for results
80-
:type C: Union[tir.Buffer, tir.Var]
81-
:param transpose_A: Whether to transpose matrix A. Defaults to False.
82-
:type transpose_A: bool, optional
83-
:param transpose_B: Whether to transpose matrix B. Defaults to False.
84-
:type transpose_B: bool, optional
85-
:param policy: Warp execution policy. Defaults to GemmWarpPolicy.Square.
86-
:type policy: GemmWarpPolicy, optional
87-
:param clear_accum: Whether to clear accumulator before computation. Defaults to False.
88-
:type clear_accum: bool, optional
89-
:param k_pack: Number of k dimensions packed into a single warp. Defaults to 1.
90-
:type k_pack: int, optional
91-
:param wg_wait: Warp group wait count. Defaults to 0.
92-
:type wg_wait: int, optional
93-
:param mbar: mbarrier for TCGEN5MMA synchronization
94-
:type mbar: tir.Buffer, optional
95-
96-
:returns: A handle to the GEMM operation
97-
:rtype: tir.Call
98-
99-
:raises AssertionError: If the K dimensions of matrices A and B don't match
39+
GEMM v2: use op tl.gemm_py.
10040

10141

10242
.. py:data:: gemm

_sources/autoapi/tilelang/layout/swizzle/index.rst.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Module Contents
3737
3838
.. py:function:: make_full_bank_swizzled_layout(*args)
3939
40-
:param args: buffer or (stride, continuous, element_size)
40+
:param args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
4141

4242
.. rubric:: Examples
4343

@@ -47,7 +47,7 @@ Module Contents
4747

4848
.. py:function:: make_half_bank_swizzled_layout(*args)
4949
50-
:param args: buffer or (stride, continuous, element_size)
50+
:param args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
5151

5252
.. rubric:: Examples
5353

@@ -57,7 +57,7 @@ Module Contents
5757

5858
.. py:function:: make_quarter_bank_swizzled_layout(*args)
5959
60-
:param args: buffer or (stride, continuous, element_size)
60+
:param args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
6161

6262
.. rubric:: Examples
6363

@@ -67,7 +67,7 @@ Module Contents
6767

6868
.. py:function:: make_linear_layout(*args)
6969
70-
:param args: buffer or (stride, continuous)
70+
:param args: buffer/BufferLoad/BufferRegion or (stride, continuous)
7171

7272
.. rubric:: Examples
7373

_sources/autoapi/tilelang/tileop/gemm/gemm_base/index.rst.txt

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,13 @@ Module Contents
9898

9999

100100

101-
.. py:property:: APtr
102-
:type: tvm.tir.PrimExpr
101+
.. py:property:: ARegion
103102
104103
104+
.. py:property:: BRegion
105105
106-
.. py:property:: BPtr
107-
:type: tvm.tir.PrimExpr
108-
109-
110-
111-
.. py:property:: CPtr
112-
:type: tvm.tir.PrimExpr
113106
107+
.. py:property:: CRegion
114108
115109
116110
.. py:property:: stride_A
@@ -161,3 +155,31 @@ Module Contents
161155
.. py:property:: C_coords
162156
163157
158+
.. py:method:: get_region_base_offsets(region)
159+
160+
Get the base offset (start index) for each dimension from a BufferRegion.
161+
162+
For example, if region is A_shared[ko % 2, 0:128, 0:64],
163+
this returns [ko % 2, 0, 0]
164+
165+
:param region: BufferRegion object
166+
167+
:returns: List of PrimExpr representing the base offset for each dimension
168+
169+
170+
171+
.. py:property:: A_base_offsets
172+
173+
Get base offsets for each dimension of A region
174+
175+
176+
.. py:property:: B_base_offsets
177+
178+
Get base offsets for each dimension of B region
179+
180+
181+
.. py:property:: C_base_offsets
182+
183+
Get base offsets for each dimension of C region
184+
185+

_sources/autoapi/tilelang/tileop/gemm/index.rst.txt

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -94,80 +94,58 @@ Package Contents
9494
Bases: :py:obj:`tvm.ir.base.Node`, :py:obj:`tvm.runtime.Scriptable`
9595

9696

97-
.. py:attribute:: A
98-
:type: tvm.tir.Buffer
97+
.. py:property:: A
9998
10099
101-
.. py:attribute:: B
102-
:type: tvm.tir.Buffer
100+
.. py:property:: B
103101
104102
105-
.. py:attribute:: C
106-
:type: tvm.tir.Buffer
103+
.. py:property:: C
107104
108105
109-
.. py:attribute:: APtr
110-
:type: tvm.tir.PrimExpr
106+
.. py:property:: APtr
111107
112108
113-
.. py:attribute:: BPtr
114-
:type: tvm.tir.PrimExpr
109+
.. py:property:: BPtr
115110
116111
117-
.. py:attribute:: CPtr
118-
:type: tvm.tir.PrimExpr
112+
.. py:property:: CPtr
119113
120114
121-
.. py:attribute:: M
122-
:type: int
115+
.. py:property:: M
123116
124117
125-
.. py:attribute:: N
126-
:type: int
118+
.. py:property:: N
127119
128120
129-
.. py:attribute:: K
130-
:type: int
121+
.. py:property:: K
131122
132123
133-
.. py:attribute:: trans_A
134-
:type: bool
124+
.. py:property:: trans_A
135125
136126
137-
.. py:attribute:: trans_B
138-
:type: bool
127+
.. py:property:: trans_B
139128
140129
141-
.. py:attribute:: stride_A
142-
:type: int
130+
.. py:property:: stride_A
143131
144132
145-
.. py:attribute:: stride_B
146-
:type: int
133+
.. py:property:: stride_B
147134
148135
149-
.. py:attribute:: offset_A
150-
:type: int
136+
.. py:property:: offset_A
151137
152138
153-
.. py:attribute:: offset_B
154-
:type: int
139+
.. py:property:: offset_B
155140
156141
157-
.. py:attribute:: clear_accum
158-
:type: bool
142+
.. py:property:: clear_accum
159143
160144
161-
.. py:attribute:: k_pack
162-
:type: int
145+
.. py:property:: k_pack
163146
164147
165-
.. py:attribute:: wg_wait
166-
:type: int
167-
168-
169-
.. py:attribute:: policy
170-
:type: tilelang.ir.GemmWarpPolicy
148+
.. py:property:: wg_wait
171149
172150
173151
.. py:method:: infer_layout(target, thread_nums)

0 commit comments

Comments
 (0)