Skip to content

Commit 951225f

Browse files
committed
kernel: use cp_async_zfill instead of cute::clear for oob handling
1 parent 65f3f9c commit 951225f

File tree

4 files changed

+68
-35
lines changed

4 files changed

+68
-35
lines changed

src/kernels/attention/attention_kernel_sm80.cuh

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
143143
auto produce_q = [&]() {
144144
auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
145145
auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
146-
safe_copy<EVEN_K,
147-
/*EVEN_MN=*/false,
148-
/*ZERO_FILL_MN=*/true,
149-
/*ZERO_FILL_K=*/true>(
146+
safe_copy</*EVEN_MN=*/false, EVEN_K>(
150147
gmem_tiled_copy_Q,
151148
tQgQ,
152149
tQsQ,
@@ -159,10 +156,9 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
159156
auto produce_k = [&](int ni) {
160157
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
161158
// skip zfill_mn for k since mask will mask out oob with -inf
162-
safe_copy<EVEN_K,
163-
/*EVEN_MN=*/false,
164-
/*ZERO_FILL_MN=*/false,
165-
/*ZERO_FILL_K=*/true>(
159+
safe_copy</*EVEN_MN=*/false,
160+
EVEN_K,
161+
/*ZERO_FILL_MN=*/false>(
166162
gmem_tiled_copy_KV,
167163
tKgK,
168164
tKsK,
@@ -174,10 +170,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
174170
auto produce_v = [&](int ni) {
175171
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
176172
// skipping ZFILL_MN for v may cause nan issue
177-
safe_copy<EVEN_K,
178-
/*EVEN_MN=*/false,
179-
/*ZERO_FILL_MN=*/true,
180-
/*ZERO_FILL_K=*/true>(
173+
safe_copy</*EVEN_MN=*/false, EVEN_K>(
181174
gmem_tiled_copy_KV,
182175
tVgV,
183176
tVsV,
@@ -302,8 +295,8 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
302295

303296
// wait for smem copy done before gmem copy
304297
__syncthreads();
305-
safe_copy<EVEN_K,
306-
/*EVEN_MN=*/false,
298+
safe_copy</*EVEN_MN=*/false,
299+
EVEN_K,
307300
/*ZERO_FILL_MN=*/false,
308301
/*ZERO_FILL_K=*/false>(
309302
gmem_tiled_copy_O,

src/kernels/attention/attention_traits_sm80.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,15 @@ struct AttentionTraitsSM80 {
9797
// Tiled copy for QKV
9898
// g2s tiled copy for q
9999
using GmemTiledCopyQ = decltype(make_tiled_copy(
100-
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DType>{},
100+
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>, DType>{},
101101
GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4)
102102
Layout<Shape<_1, _8>>{} // Val layout: 8 vals per read
103103
));
104104

105105
// g2s tiled copy for kv
106106
// TODO: choose based on BLK_K and kv cache type
107107
using GmemTiledCopyKV = decltype(make_tiled_copy(
108-
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, KV_DType>{},
108+
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>, KV_DType>{},
109109
GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4)
110110
Layout<Shape<_1, _8>>{} // Val layout: 8 vals per read
111111
));

src/kernels/attention/cute_extensions.cuh

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,62 @@ CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a,
2222
return elem_less(get<I>(a), get<I>(b));
2323
}
2424

25-
template <bool EVEN_K,
26-
bool EVEN_MN,
27-
bool ZERO_FILL_MN,
28-
bool ZERO_FILL_K,
29-
class TiledCopy,
25+
template <class Copy_Atom, class TensorS, class TensorD>
26+
CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom,
27+
const TensorS& src,
28+
TensorD&& dst) {
29+
CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch.");
30+
31+
auto has_with_bool = cute::is_valid(
32+
[](auto t) -> void_t<decltype(declval<typename decltype(t)::Traits>()
33+
.with(true))> {},
34+
copy_atom);
35+
if constexpr (has_with_bool) {
36+
constexpr int R = TensorD::rank;
37+
if constexpr (R == 1) { // Dispatch the copy
38+
copy_atom.with(false).call(src, dst);
39+
} else { // Loop over all but the first mode
40+
Tensor src_v = group_modes<1, R>(src);
41+
Tensor dst_v = group_modes<1, R>(dst);
42+
CUTE_UNROLL
43+
for (int i = 0; i < size<1>(dst_v); ++i) {
44+
copy_atom.with(false).call(src_v(_, i), dst_v(_, i));
45+
}
46+
}
47+
} else {
48+
// just call clear if no with method
49+
clear(dst);
50+
}
51+
}
52+
53+
template <class Copy_Atom, class TensorS, class TensorD>
54+
CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom,
55+
const TensorS& src,
56+
TensorD& dst) {
57+
zfill(copy_atom, src, dst);
58+
}
59+
60+
template <bool EVEN_MN,
61+
bool EVEN_K,
62+
bool ZFILL_MN = true,
63+
bool ZFILL_K = true,
64+
class CopyAtom,
65+
class TV,
66+
class Tiler,
3067
class TensorS,
3168
class TensorD,
3269
class TensorC,
3370
class Coord>
3471
CUTE_HOST_DEVICE void safe_copy(
35-
const TiledCopy& tiled_copy,
72+
const TiledCopy<CopyAtom, TV, Tiler>& tiled_copy,
3673
const TensorS& src, // (CPY, CPY_M/N, CPY_K)
3774
TensorD& dst, // (CPY, CPY_M/N, CPY_K)
3875
const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (blk_m/n, blk_k)
3976
const Coord& max_coord // max_coord(blk_m/n, blk_k)
4077
) {
78+
CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch.");
79+
auto copy_atom = static_cast<const CopyAtom&>(tiled_copy);
80+
4181
if constexpr (!EVEN_MN && !EVEN_K) {
4282
// handle both m/n and k oob
4383
CUTE_UNROLL
@@ -46,16 +86,16 @@ CUTE_HOST_DEVICE void safe_copy(
4686
CUTE_UNROLL
4787
for (int ki = 0; ki < size<2>(src); ++ki) {
4888
if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
49-
copy(tiled_copy, src(_, mi, ki), dst(_, mi, ki));
89+
copy(copy_atom, src(_, mi, ki), dst(_, mi, ki));
5090
} else {
51-
if constexpr (ZERO_FILL_K) {
52-
clear(dst(_, mi, ki));
91+
if constexpr (ZFILL_K) {
92+
zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
5393
}
5494
}
5595
}
5696
} else {
57-
if constexpr (ZERO_FILL_MN) {
58-
clear(dst(_, mi, _));
97+
if constexpr (ZFILL_MN) {
98+
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
5999
}
60100
}
61101
}
@@ -64,10 +104,10 @@ CUTE_HOST_DEVICE void safe_copy(
64104
CUTE_UNROLL
65105
for (int mi = 0; mi < size<1>(src); ++mi) {
66106
if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) {
67-
copy(tiled_copy, src(_, mi, _), dst(_, mi, _));
107+
copy(copy_atom, src(_, mi, _), dst(_, mi, _));
68108
} else {
69-
if constexpr (ZERO_FILL_MN) {
70-
clear(dst(_, mi, _));
109+
if constexpr (ZFILL_MN) {
110+
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
71111
}
72112
}
73113
}
@@ -76,16 +116,16 @@ CUTE_HOST_DEVICE void safe_copy(
76116
CUTE_UNROLL
77117
for (int ki = 0; ki < size<2>(src); ++ki) {
78118
if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
79-
copy(tiled_copy, src(_, _, ki), dst(_, _, ki));
119+
copy(copy_atom, src(_, _, ki), dst(_, _, ki));
80120
} else {
81-
if constexpr (ZERO_FILL_K) {
82-
clear(dst(_, _, ki));
121+
if constexpr (ZFILL_K) {
122+
zfill(copy_atom, src(_, _, ki), dst(_, _, ki));
83123
}
84124
}
85125
}
86126
} else {
87127
// no oob, just copy
88-
copy(tiled_copy, src, dst);
128+
copy(copy_atom, src, dst);
89129
}
90130
}
91131

src/kernels/attention/tools/attention_traits_viewer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ void test_attn_traits() {
139139
auto thr_mma = tiled_mma.get_slice(0);
140140
// (MMA, MMA_N, MMA_K)
141141
// ((_2,_2),_8,_4):((_1,_2),_16,_4)
142-
auto tSrK = partition_fragment_B(thr_mma, sK);
142+
auto tSrK = thr_mma.partition_fragment_B(sK);
143143
print(tSrK);print("\n");
144144

145145
auto tSrK_fp8 = make_fragment_like<cute::int8_t>(tSrK);

0 commit comments

Comments
 (0)