Skip to content

Commit 6c811c9

Browse files
authored
[swizzle] add padding -> swizzle layout tools🎉 (#198)
* Update README.md
* add pad -> swizzle layout tools
1 parent fd993a9 commit 6c811c9

File tree

4 files changed

+123
-31
lines changed

4 files changed

+123
-31
lines changed

kernels/flash-attn/tools/print_swizzle_layout.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,
4040

4141

4242
def print_smem_swizzle_layout(rows: int = 16,
43-
logical_col_stride: int = 16,
44-
num_elems_per_128b: int = 8,
45-
show_logical_col_id: bool = False,
46-
use_logical_col_stride: bool = False):
43+
logical_col_stride: int = 16,
44+
num_elems_per_128b: int = 8,
45+
smem_pading: int = 0,
46+
show_logical_col_id: bool = False,
47+
use_logical_col_stride: bool = False):
4748
# ----------------------------------------------------------------
4849
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
4950
# [INFO] For logical_col_stride > 16, we have to permute the |
@@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
9596
# ----------------------------------------------------------------
9697
str_len = 0
9798
total_banks = 0
99+
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
98100
# 4 bytes per bank
99101
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
100102
if use_logical_col_stride:
101103
banks_per_col = int((logical_col_stride * 2) / 4)
102104
if logical_col_stride > 16:
103105
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
106+
if smem_pading == 8:
107+
banks_per_col += 4
108+
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")
109+
104110
banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
105111
for i in range(rows):
106112
layout_str_len = 0
@@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
139145
num_elems_per_128b)
140146
logical_col_ids.append(j)
141147
smem_layout_col_ids.append(layout_j)
148+
142149
smem_layout_str = f"|row {i:<2}|"
150+
151+
r = 0
143152
for c, l in zip(logical_col_ids, smem_layout_col_ids):
144-
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
145-
show_logical_col_id else f"{l:<2}"),
146-
sep=" ",
147-
width=max_bank_str_len-1,
148-
return_str=True) + "|"
153+
smem_layout_str += pretty_print_line(
154+
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
155+
sep=" ",
156+
width=(max_bank_str_len-1),
157+
return_str=True
158+
) + "|"
159+
r += 1
160+
if logical_col_stride >= 16:
161+
if smem_pading == 8 and (r > 1 and r % 2 == 0):
162+
smem_layout_str += pretty_print_line(
163+
(f"pad"),
164+
sep=" ", width=max_bank_str_len-1,
165+
return_str=True
166+
) + "|"
167+
else:
168+
if smem_pading == 8:
169+
smem_layout_str += pretty_print_line(
170+
(f"pad"),
171+
sep=" ", width=max_bank_str_len-1,
172+
return_str=True
173+
) + "|"
174+
149175
layout_str_len = len(smem_layout_str)
150176
str_len = max(layout_str_len, banks_str_len)
151177

@@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
172198
def get_args():
173199
parser = argparse.ArgumentParser()
174200
parser.add_argument("--rows", type=int, default=16)
201+
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
175202
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
176203
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
177204
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
@@ -186,6 +213,7 @@ def get_args():
186213
print_smem_swizzle_layout(rows=args.rows,
187214
logical_col_stride=args.logical_col_stride,
188215
num_elems_per_128b=args.num_elems_per_128b,
216+
smem_pading=args.smem_padding,
189217
show_logical_col_id=args.show_logical_col_id,
190218
use_logical_col_stride=args.use_logical_col_stride)
191219

kernels/hgemm/tools/print_swizzle_layout.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,
4040

4141

4242
def print_smem_swizzle_layout(rows: int = 16,
43-
logical_col_stride: int = 16,
44-
num_elems_per_128b: int = 8,
45-
show_logical_col_id: bool = False,
46-
use_logical_col_stride: bool = False):
43+
logical_col_stride: int = 16,
44+
num_elems_per_128b: int = 8,
45+
smem_pading: int = 0,
46+
show_logical_col_id: bool = False,
47+
use_logical_col_stride: bool = False):
4748
# ----------------------------------------------------------------
4849
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
4950
# [INFO] For logical_col_stride > 16, we have to permute the |
@@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
9596
# ----------------------------------------------------------------
9697
str_len = 0
9798
total_banks = 0
99+
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
98100
# 4 bytes per bank
99101
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
100102
if use_logical_col_stride:
101103
banks_per_col = int((logical_col_stride * 2) / 4)
102104
if logical_col_stride > 16:
103105
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
106+
if smem_pading == 8:
107+
banks_per_col += 4
108+
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")
109+
104110
banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
105111
for i in range(rows):
106112
layout_str_len = 0
@@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
139145
num_elems_per_128b)
140146
logical_col_ids.append(j)
141147
smem_layout_col_ids.append(layout_j)
148+
142149
smem_layout_str = f"|row {i:<2}|"
150+
151+
r = 0
143152
for c, l in zip(logical_col_ids, smem_layout_col_ids):
144-
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
145-
show_logical_col_id else f"{l:<2}"),
146-
sep=" ",
147-
width=max_bank_str_len-1,
148-
return_str=True) + "|"
153+
smem_layout_str += pretty_print_line(
154+
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
155+
sep=" ",
156+
width=(max_bank_str_len-1),
157+
return_str=True
158+
) + "|"
159+
r += 1
160+
if logical_col_stride >= 16:
161+
if smem_pading == 8 and (r > 1 and r % 2 == 0):
162+
smem_layout_str += pretty_print_line(
163+
(f"pad"),
164+
sep=" ", width=max_bank_str_len-1,
165+
return_str=True
166+
) + "|"
167+
else:
168+
if smem_pading == 8:
169+
smem_layout_str += pretty_print_line(
170+
(f"pad"),
171+
sep=" ", width=max_bank_str_len-1,
172+
return_str=True
173+
) + "|"
174+
149175
layout_str_len = len(smem_layout_str)
150176
str_len = max(layout_str_len, banks_str_len)
151177

@@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
172198
def get_args():
173199
parser = argparse.ArgumentParser()
174200
parser.add_argument("--rows", type=int, default=16)
201+
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
175202
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
176203
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
177204
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
@@ -186,6 +213,7 @@ def get_args():
186213
print_smem_swizzle_layout(rows=args.rows,
187214
logical_col_stride=args.logical_col_stride,
188215
num_elems_per_128b=args.num_elems_per_128b,
216+
smem_pading=args.smem_padding,
189217
show_logical_col_id=args.show_logical_col_id,
190218
use_logical_col_stride=args.use_logical_col_stride)
191219

kernels/swizzle/hgemm_mma_swizzle.cu

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4(
520520
constexpr int MMA_TILE_N = 4;
521521
constexpr int WARP_TILE_M = 4;
522522
constexpr int WARP_TILE_N = 4;
523-
// bank conflicts free via pad = 8, 拒绝幻想,相信profile
523+
// bank conflicts free via pad = 8.
524524
// ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_swizzle.bin
525525
// ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin
526526
// constexpr int A_PAD = 8;
@@ -541,6 +541,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4(
541541
);
542542
}
543543

544+
template <const int B_PAD = 8>
544545
void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle(
545546
half* a, half* b, half* c, int M, int N, int K) {
546547
constexpr int MMA_M = 16;
@@ -551,7 +552,7 @@ void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle(
551552
constexpr int WARP_TILE_M = 4;
552553
constexpr int WARP_TILE_N = 4;
553554
constexpr int A_PAD = 0;
554-
constexpr int B_PAD = 8;
555+
// B_PAD = 8, bank conflicts free via pad = 8.
555556
constexpr int NUM_THREADS= (
556557
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
557558
dim3 block(NUM_THREADS);
@@ -644,9 +645,16 @@ int main(int argc, char *argv[]) {
644645
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
645646
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
646647
printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);
648+
649+
printf("\nALGO = HGEMM mma2x4_warp4x4 + A SMEM SWIZZLE + B_PAD 0\n");
650+
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle<0>,
651+
M, N, K, W, R);
652+
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
653+
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
654+
printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);
647655

648-
printf("\nALGO = HGEMM mma2x4_warp4x4 + SMEM SWIZZLE\n");
649-
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle,
656+
printf("\nALGO = HGEMM mma2x4_warp4x4 + A SMEM SWIZZLE + B_PAD 8\n");
657+
avg_sec = perf_gemm<half>(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle<8>,
650658
M, N, K, W, R);
651659
avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
652660
printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);

kernels/swizzle/print_swizzle_layout.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ def swizzle_permuted_j(i: int,
4040

4141

4242
def print_smem_swizzle_layout(rows: int = 16,
43-
logical_col_stride: int = 16,
44-
num_elems_per_128b: int = 8,
45-
show_logical_col_id: bool = False,
46-
use_logical_col_stride: bool = False):
43+
logical_col_stride: int = 16,
44+
num_elems_per_128b: int = 8,
45+
smem_pading: int = 0,
46+
show_logical_col_id: bool = False,
47+
use_logical_col_stride: bool = False):
4748
# ----------------------------------------------------------------
4849
# [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
4950
# [INFO] For logical_col_stride > 16, we have to permute the |
@@ -95,12 +96,17 @@ def print_smem_swizzle_layout(rows: int = 16,
9596
# ----------------------------------------------------------------
9697
str_len = 0
9798
total_banks = 0
99+
assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
98100
# 4 bytes per bank
99101
banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
100102
if use_logical_col_stride:
101103
banks_per_col = int((logical_col_stride * 2) / 4)
102104
if logical_col_stride > 16:
103105
print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
106+
if smem_pading == 8:
107+
banks_per_col += 4
108+
print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}")
109+
104110
banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
105111
for i in range(rows):
106112
layout_str_len = 0
@@ -139,13 +145,33 @@ def print_smem_swizzle_layout(rows: int = 16,
139145
num_elems_per_128b)
140146
logical_col_ids.append(j)
141147
smem_layout_col_ids.append(layout_j)
148+
142149
smem_layout_str = f"|row {i:<2}|"
150+
151+
r = 0
143152
for c, l in zip(logical_col_ids, smem_layout_col_ids):
144-
smem_layout_str += pretty_print_line((f"{c:>2}:{l:<2}" if
145-
show_logical_col_id else f"{l:<2}"),
146-
sep=" ",
147-
width=max_bank_str_len-1,
148-
return_str=True) + "|"
153+
smem_layout_str += pretty_print_line(
154+
(f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
155+
sep=" ",
156+
width=(max_bank_str_len-1),
157+
return_str=True
158+
) + "|"
159+
r += 1
160+
if logical_col_stride >= 16:
161+
if smem_pading == 8 and (r > 1 and r % 2 == 0):
162+
smem_layout_str += pretty_print_line(
163+
(f"pad"),
164+
sep=" ", width=max_bank_str_len-1,
165+
return_str=True
166+
) + "|"
167+
else:
168+
if smem_pading == 8:
169+
smem_layout_str += pretty_print_line(
170+
(f"pad"),
171+
sep=" ", width=max_bank_str_len-1,
172+
return_str=True
173+
) + "|"
174+
149175
layout_str_len = len(smem_layout_str)
150176
str_len = max(layout_str_len, banks_str_len)
151177

@@ -172,6 +198,7 @@ def print_smem_swizzle_layout(rows: int = 16,
172198
def get_args():
173199
parser = argparse.ArgumentParser()
174200
parser.add_argument("--rows", type=int, default=16)
201+
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
175202
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
176203
parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64)
177204
parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true")
@@ -186,6 +213,7 @@ def get_args():
186213
print_smem_swizzle_layout(rows=args.rows,
187214
logical_col_stride=args.logical_col_stride,
188215
num_elems_per_128b=args.num_elems_per_128b,
216+
smem_pading=args.smem_padding,
189217
show_logical_col_id=args.show_logical_col_id,
190218
use_logical_col_stride=args.use_logical_col_stride)
191219

0 commit comments

Comments
 (0)