@@ -18,6 +18,7 @@ Classes
18
18
.. autoapisummary::
19
19
20
20
tilelang.intrinsics.mfma_macro_generator.MatrixCoreIntrinEmitter
21
+ tilelang.intrinsics.mfma_macro_generator.MatrixCorePreshuffleIntrinEmitter
21
22
22
23
23
24
Module Contents
@@ -159,3 +160,92 @@ Module Contents
159
160
.. py:method:: stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
160
161
161
162
163
+ .. py:class:: MatrixCorePreshuffleIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, k_pack=None, is_m_first=False, a_preshuffle=False, b_preshuffle=False)
164
+
165
+ Bases: :py:obj:`MatrixCoreIntrinEmitter`
166
+
167
+
168
+ To eliminate Python syntax within TIR Macro.
169
+
170
+
171
+ .. py:attribute:: a_dtype
172
+ :value: 'float16'
173
+
174
+
175
+
176
+ .. py:attribute:: b_dtype
177
+ :value: 'float16'
178
+
179
+
180
+
181
+ .. py:attribute:: accum_dtype
182
+ :value: 'float16'
183
+
184
+
185
+
186
+ .. py:attribute:: a_transposed
187
+ :value: False
188
+
189
+
190
+
191
+ .. py:attribute:: b_transposed
192
+ :value: False
193
+
194
+
195
+
196
+ .. py:attribute:: block_row_warps
197
+ :value: 2
198
+
199
+
200
+
201
+ .. py:attribute:: block_col_warps
202
+ :value: 2
203
+
204
+
205
+
206
+ .. py:attribute:: warp_row_tiles
207
+ :value: 8
208
+
209
+
210
+
211
+ .. py:attribute:: warp_col_tiles
212
+ :value: 8
213
+
214
+
215
+
216
+ .. py:attribute:: chunk
217
+ :value: 16
218
+
219
+
220
+
221
+ .. py:attribute:: warp_rows
222
+ :value: 0
223
+
224
+
225
+
226
+ .. py:attribute:: warp_cols
227
+ :value: 0
228
+
229
+
230
+
231
+ .. py:attribute:: reduce_k
232
+ :value: 1
233
+
234
+
235
+
236
+ .. py:attribute:: threads
237
+ :value: 256
238
+
239
+
240
+
241
+ .. py:attribute:: num_elems_per_byte
242
+ :value: 1
243
+
244
+
245
+
246
+ .. py:method:: ldmatrix_a(A_local_buf, A_buf, ki, rk=0, pid_m=None, pid_n=None)
247
+
248
+
249
+ .. py:method:: ldmatrix_b(B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None)
250
+
251
+
0 commit comments