Skip to content

Commit 154054d

Browse files
committed
add mma_tile_tex.cc
1 parent 39f7a3f commit 154054d

File tree

1 file changed

+358
-0
lines changed

1 file changed

+358
-0
lines changed

kernels/cutlass/cute/mma_tile_tex.cc

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
/* makefile
2+
copy from: https://gist.github.com/66RING/2e188b73fdf703e9f9dfc7371814dd15#file-mma_tile_tex-cpp
3+
4+
render: build
5+
./build/main > mma_tile.tex
6+
xelatex --cnf-line=main_memory=12000000 --halt-on-error mma_tile.tex && rm -rf *.aux *.log *.out
7+
8+
build:
9+
cmake -B build
10+
cmake --build build
11+
12+
.PHONY: build
13+
14+
15+
*/
16+
17+
18+
#include "cute/tensor.hpp"
19+
#include "cutlass/numeric_types.h"
20+
21+
22+
// Writes the LaTeX preamble to stdout: the document class, the packages used
// by the generated figures, and the opening \begin{document}.
void print_header() {
  // One entry per emitted line; each entry carries its own trailing newline.
  static constexpr const char* kPreambleLines[] = {
      "\\documentclass{article}\n",
      "\\usepackage[a4paper, margin=0.5cm]{geometry}\n",
      "\\usepackage{adjustbox}\n",
      "\\usepackage{graphicx}\n",
      "\\usepackage{lipsum}\n",
      "\\usepackage{tikz}\n",
      "\n",
      "\\begin{document}\n",
  };

  for (const char* line : kPreambleLines) {
    printf("%s", line);
  }
}
34+
35+
36+
// Writes the closing \end{document} line to stdout.
void print_footer() {
  printf("%s", "\\end{document}\n");
}
40+
41+
42+
// Copy from mma_atom.hpp
43+
//
44+
// Modified to remove printing header and footder, hence allows printing
45+
// multiple MMAs per TEX file for easier comparisons.
46+
template <class AtomLayoutMNK,
47+
class ValLayoutMNK,
48+
class PermutationsMNK,
49+
class LayoutC, class ThrIDC,
50+
class LayoutA, class ThrIDA,
51+
class LayoutB, class ThrIDB>
52+
void
53+
print_mma(const char* name,
54+
const AtomLayoutMNK& atom_layout_mnk,
55+
const ValLayoutMNK& val_layout_mnk,
56+
const PermutationsMNK& permutations_mnk,
57+
LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
58+
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
59+
LayoutB const& B, ThrIDB const& TB) { // (n,k) -> (tid,vid) and tid -> thr_idx
60+
using namespace cute;
61+
62+
printf("\\begin{verbatim}\n");
63+
printf("\n%s\n\n", name);
64+
65+
printf(" AtomLayoutMNK: "); print(atom_layout_mnk); printf("\n");
66+
printf(" ValLayoutMNK: "); print(val_layout_mnk); printf("\n");
67+
printf("PermutationsMNK: "); print(permutations_mnk); printf("\n\n");
68+
69+
printf("LayoutC: "); print(C); printf("\n");
70+
printf(" ThrIDC: "); print(TC); printf("\n");
71+
printf("LayoutA: "); print(A); printf("\n");
72+
printf(" ThrIDA: "); print(TA); printf("\n");
73+
printf("LayoutB: "); print(B); printf("\n");
74+
printf(" ThrIDB: "); print(TB); printf("\n");
75+
printf("\\end{verbatim}\n");
76+
77+
// printf("\\begin{adjustbox}{max height=0.7\\textheight,max width=\\textwidth}%");
78+
printf("\\begin{adjustbox}{max height=0.7\\textheight,max width=\\textwidth}\n");
79+
printf("\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/"
80+
".style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n");
81+
char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
82+
"{rgb,255:red,175;green,255;blue,175}",
83+
"{rgb,255:red,255;green,255;blue,175}",
84+
"{rgb,255:red,255;green,175;blue,175}",
85+
"{rgb,255:red,210;green,210;blue,255}",
86+
"{rgb,255:red,210;green,255;blue,210}",
87+
"{rgb,255:red,255;green,255;blue,210}",
88+
"{rgb,255:red,255;green,210;blue,210}"};
89+
90+
// C starting at 0,0
91+
for (int m = 0; m < size<0>(C); ++m) {
92+
for (int n = 0; n < size<1>(C); ++n) {
93+
int thrid = C(m,n) % size(TC);
94+
int val_idx = C(m,n) / size(TC);
95+
int thr_idx = TC(thrid);
96+
97+
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
98+
color_map[thr_idx % 8],
99+
m, n,
100+
thr_idx, val_idx);
101+
}
102+
}
103+
104+
// A starting at 0,-size<1>(A)-1
105+
for (int m = 0; m < size<0>(A); ++m) {
106+
for (int k = 0; k < size<1>(A); ++k) {
107+
int thrid = A(m,k) % size(TA);
108+
int val_idx = A(m,k) / size(TA);
109+
int thr_idx = TA(thrid);
110+
111+
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
112+
color_map[thr_idx % 8],
113+
m, k-1-size<1>(A),
114+
thr_idx, val_idx);
115+
}
116+
}
117+
118+
// B starting at -size<1>(B)-1,0
119+
for (int n = 0; n < size<0>(B); ++n) {
120+
for (int k = 0; k < size<1>(B); ++k) {
121+
int thrid = B(n,k) % size(TB);
122+
int val_idx = B(n,k) / size(TB);
123+
int thr_idx = TB(thrid);
124+
125+
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
126+
color_map[thr_idx % 8],
127+
k-1-size<1>(B), n,
128+
thr_idx, val_idx);
129+
}
130+
}
131+
132+
// A labels
133+
for (int m = 0, k = -1; m < size<0>(A); ++m) {
134+
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m);
135+
}
136+
for (int k = 0, m = -1; k < size<1>(A); ++k) {
137+
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k);
138+
}
139+
// B labels
140+
for (int n = 0, k = -1; n < size<0>(B); ++n) {
141+
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n);
142+
}
143+
for (int k = 0, n = -1; k < size<1>(B); ++k) {
144+
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k);
145+
}
146+
147+
printf("\\end{tikzpicture}\n\\end{adjustbox}%\n");
148+
}
149+
150+
151+
template <class TiledCopy,
152+
class LayoutS, class ThrIDS,
153+
class LayoutD, class ThrIDD>
154+
void
155+
print_copy(const char* name,
156+
TiledCopy& copy,
157+
LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx
158+
LayoutD const& D, ThrIDD const& TD) { // (m,n) -> (tid,vid) and tid -> thr_idx
159+
using namespace cute;
160+
161+
CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{});
162+
CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{});
163+
164+
assert(size<0>(S) == size<0>(D));
165+
assert(size<1>(S) == size<1>(D));
166+
167+
printf("\\begin{verbatim}\n");
168+
printf("\n%s\n\n", name);
169+
printf("LayoutCopy_TV: "); print(typename TiledCopy::TiledLayout_TV{}); printf("\n");
170+
printf(" ShapeTile_MN: "); print(typename TiledCopy::Tiler_MN{}); printf("\n\n");
171+
172+
printf(" LayoutS: "); print(S); printf("\n");
173+
printf(" ThrIDS: "); print(TS); printf("\n");
174+
printf(" LayoutD: "); print(D); printf("\n");
175+
printf(" ThrIDD: "); print(TD); printf("\n");
176+
printf("\\end{verbatim}\n");
177+
178+
printf("\\begin{adjustbox}{max height=0.7\\textheight,max width=\\textwidth}\n");
179+
printf("\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/"
180+
".style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n");
181+
char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
182+
"{rgb,255:red,175;green,255;blue,175}",
183+
"{rgb,255:red,255;green,255;blue,175}",
184+
"{rgb,255:red,255;green,175;blue,175}",
185+
"{rgb,255:red,210;green,210;blue,255}",
186+
"{rgb,255:red,210;green,255;blue,210}",
187+
"{rgb,255:red,255;green,255;blue,210}",
188+
"{rgb,255:red,255;green,210;blue,210}"};
189+
190+
// S starting at 0,0
191+
for (int i = 0; i < size<0>(S); ++i) {
192+
for (int j = 0; j < size<1>(S); ++j) {
193+
int thrid = S(i,j) % size(TS);
194+
int val_idx = S(i,j) / size(TS);
195+
int thr_idx = TS(thrid);
196+
197+
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
198+
color_map[thr_idx % 8],
199+
i, j,
200+
thr_idx, val_idx);
201+
}
202+
}
203+
204+
// D starting at 0,size<1>(S)+3
205+
for (int i = 0; i < size<0>(D); ++i) {
206+
for (int j = 0; j < size<1>(D); ++j) {
207+
int thrid = D(i,j) % size(TD);
208+
int val_idx = D(i,j) / size(TD);
209+
int thr_idx = TD(thrid);
210+
211+
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
212+
color_map[thr_idx % 8],
213+
i + size<0>(S) + 3, j,
214+
thr_idx, val_idx);
215+
}
216+
}
217+
218+
// S Labels
219+
for (int i = 0, j = -1; i < size<0>(S); ++i) {
220+
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
221+
}
222+
for (int j = 0, i = -1; j < size<1>(S); ++j) {
223+
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
224+
}
225+
// D Labels
226+
for (int i = 0, j = -1; i < size<0>(S); ++i) {
227+
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i + size<0>(S) + 3, j, i);
228+
}
229+
for (int j = 0, i = -1; j < size<1>(D); ++j) {
230+
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i + size<0>(S) + 3, j, j);
231+
}
232+
233+
printf("\\end{tikzpicture}\n\\end{adjustbox}%\n");
234+
}
235+
236+
237+
template <class MMA_Atom_Arch,
238+
class AtomLayoutMNK,
239+
class ValLayoutMNK,
240+
class PermutationsMNK>
241+
void print_mma_content(
242+
const char* name,
243+
cute::TiledMMA<MMA_Atom_Arch, AtomLayoutMNK, ValLayoutMNK, PermutationsMNK> const& mma) {
244+
printf("\n\\newpage\n");
245+
246+
auto layout_and_thrid_C = mma.get_layoutC_MN();
247+
auto layoutC_MN = cute::get<0>(layout_and_thrid_C);
248+
auto thrID_C = cute::get<1>(layout_and_thrid_C);
249+
250+
auto layout_and_thrid_A = mma.get_layoutA_MK();
251+
auto layoutA_MK = cute::get<0>(layout_and_thrid_A);
252+
auto thrID_A = cute::get<1>(layout_and_thrid_A);
253+
254+
auto layout_and_thrid_B = mma.get_layoutB_NK();
255+
auto layoutB_NK = cute::get<0>(layout_and_thrid_B);
256+
auto thrID_B = cute::get<1>(layout_and_thrid_B);
257+
258+
print_mma(name,
259+
AtomLayoutMNK{},
260+
ValLayoutMNK{},
261+
PermutationsMNK{},
262+
layoutC_MN, thrID_C,
263+
layoutA_MK, thrID_A,
264+
layoutB_NK, thrID_B);
265+
}
266+
267+
268+
// Starts a new page and prints `copy` via print_copy(), using the source and
// destination (layout, thread-id) pairs queried from `copy`.
template <class TiledCopy>
void print_copy_content(const char* name, TiledCopy& copy) {
  printf("\n\\newpage\n");

  auto [src_layout, src_thr_id] = copy.get_layoutS_MN();
  auto [dst_layout, dst_thr_id] = copy.get_layoutD_MN();

  print_copy(name,
             copy,
             src_layout, src_thr_id,
             dst_layout, dst_thr_id);
}
279+
280+
281+
void print_layouts_for_mma() {
282+
using namespace cute;
283+
using _X = cute::Underscore;
284+
285+
{
286+
auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
287+
Layout<Shape<_4,_1, _1>>{}, // AtomLayoutMNK
288+
Layout<Shape<_1,_2, _1>>{} // ValLayoutMNK
289+
);
290+
print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
291+
}
292+
293+
{
294+
auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
295+
Layout<Shape<_1,_1, _1>>{}, // AtomLayoutMNK
296+
Layout<Shape<_1,_2, _1>>{} // ValLayoutMNK
297+
);
298+
print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
299+
}
300+
301+
{
302+
auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
303+
Layout<Shape<_2,_1, _1>>{}, // AtomLayoutMNK
304+
Layout<Shape<_1,_2, _1>>{} // ValLayoutMNK
305+
);
306+
print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
307+
}
308+
309+
{
310+
auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
311+
Layout<Shape<_4,_1, _1>>{}, // AtomLayoutMNK
312+
Layout<Shape<_1,_1, _1>>{} // ValLayoutMNK
313+
);
314+
print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
315+
}
316+
317+
{
318+
auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
319+
Layout<Shape<_1,_1, _1>>{}, // AtomLayoutMNK
320+
Layout<Shape<_1,_1, _1>>{} // ValLayoutMNK
321+
);
322+
print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
323+
}
324+
325+
{
326+
auto tiled_mma = make_tiled_mma(SM75_16x8x8_F32F16F16F32_TN{},
327+
Layout<Shape<_1,_1, _1>>{}, // AtomLayoutMNK
328+
Layout<Shape<_1,_1, _1>>{} // ValLayoutMNK
329+
);
330+
print_mma_content("flash2: SM75_16x8x8_F32F16F16F32_TN", tiled_mma);
331+
}
332+
333+
{
334+
auto tiled_mma = make_tiled_mma(SM75_16x8x8_F32F16F16F32_TN{},
335+
Layout<Shape<_1,_1, _1>>{}, // AtomLayoutMNK
336+
Layout<Shape<_1,_2, _1>>{} // ValLayoutMNK
337+
);
338+
print_mma_content("flash2: SM75_16x8x8_F32F16F16F32_TN", tiled_mma);
339+
}
340+
341+
{
342+
auto tiled_mma = make_tiled_mma(SM75_16x8x8_F32F16F16F32_TN{},
343+
Layout<Shape<_4,_1, _1>>{}, // AtomLayoutMNK
344+
Layout<Shape<_1,_2, _2>>{} // ValLayoutMNK
345+
);
346+
print_mma_content("flash2: SM75_16x8x8_F32F16F16F32_TN", tiled_mma);
347+
}
348+
}
349+
350+
// Writes a complete LaTeX document to stdout: preamble, one page per MMA
// configuration, then the closing \end{document}.
int main() {
  print_header();
  print_layouts_for_mma();
  print_footer();
  // Falling off the end of main is equivalent to `return 0;`.
}
358+

0 commit comments

Comments
 (0)