26 changes: 14 additions & 12 deletions src/kernels/attention/layout_convertor.h
@@ -9,21 +9,23 @@ using namespace cute;
// Only works for TiledMMA (64x16x16) with SM80_16x8x16_F32F16F16F32_TN
struct LayoutConvertor {
// Convert fragment layout to rowcol layout for iterating
// (MMA=4, MMA_M, MMA_N) => ((2, MMA_M), (2, MMA_N))
// (MMA=4, MMA_M, MMA_N, ...) => ((2, MMA_M), (2, MMA_N), ...)
template <typename LayoutC>
CUTE_HOST_DEVICE static constexpr auto to_mn(const LayoutC& layout) {
auto l = logical_divide(layout, Shape<_2>{});
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<2>(l)));
}
constexpr int R = LayoutC::rank;
static_assert(R >= 3, "Expected at least 3 modes in LayoutC.");

// (MMA=4, MMA_M, MMA_N, STEPS) => ((2, MMA_M), (2, MMA_N), STEPS)
template <typename LayoutC>
CUTE_HOST_DEVICE static constexpr auto to_mns(const LayoutC& layout) {
// ((2, 2), MMA_M, MMA_N, ...)
auto l = logical_divide(layout, Shape<_2>{});
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<2>(l)),
get<3>(l));
// ((2, MMA_M), (2, MMA_N), ...)
if constexpr (R > 3) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<2>(l)),
take<3, R>(l));
} else {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<2>(l)));
}
}

// Convert fragment layout from gemm-I C to gemm-II A
@@ -36,4 +38,4 @@ struct LayoutConvertor {
}
};

} // namespace llm
} // namespace llm
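Reviewer note (not part of the diff): the old `to_mn`/`to_mns` pair is folded into a single `to_mn` that forwards any trailing modes. Below is a minimal host-side sketch of the intended behaviour, using toy compact layouts in place of real `partition_fragment_C` fragments; the include path and shapes are assumptions for illustration only.

// Illustrative sketch only: toy layouts, not the real MMA fragments.
#include <cstdio>
#include <cute/tensor.hpp>
#include "layout_convertor.h"  // assumed include path

using namespace cute;

int main() {
  // (MMA=4, MMA_M=2, MMA_N=3) => ((2, 2), (2, 3))
  auto acc = make_layout(Shape<_4, _2, _3>{});
  print(llm::LayoutConvertor::to_mn(acc));        // row-col view of the fragment
  printf("\n");

  // Trailing modes (e.g. STEPS) are now forwarded unchanged:
  // (MMA=4, MMA_M=2, MMA_N=3, STEPS=2) => ((2, 2), (2, 3), (2))
  auto acc_steps = make_layout(Shape<_4, _2, _3, _2>{});
  print(llm::LayoutConvertor::to_mn(acc_steps));  // covers what to_mns used to do
  printf("\n");
  return 0;
}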
40 changes: 18 additions & 22 deletions src/kernels/attention/mask.h
@@ -12,7 +12,6 @@ struct Mask {
// Fragment type for alibi slopes
using FragmentT = decltype(make_tensor<float>(Int<ROWS_PER_THR>{}));

const int q_len_;
const int kv_len_;
const FastDivmod& group_size_;
const int sliding_window_;
@@ -24,46 +23,43 @@ struct Mask {
int kv_len,
const FastDivmod& group_size,
int sliding_window)
: q_len_(q_len),
kv_len_(kv_len),
: kv_len_(kv_len),
group_size_(group_size),
sliding_window_(sliding_window),
diagonal_offset_(kv_len - q_len) {}

// cS_mn: ((2, MMA_M), (2, MMA_N))
template <typename IdentityS>
CUTE_HOST_DEVICE void init_alibi(IdentityS& cS_mn,
int m_base_idx,
int kv_head_idx,
float sm_scale,
const float* alibi_slops_ptr) {
CUTE_HOST_DEVICE void init_alibi(
IdentityS& cS_mn, // ((2, MMA_M), (2, MMA_N)) => (M, N)
int kv_head_idx,
float sm_scale,
const float* alibi_slops_ptr) {
// copy alibi slopes to registers
CUTE_UNROLL
for (int i = 0; i < size<0>(cS_mn); ++i) {
const auto [m, n] = cS_mn(i, _0{});
const int q_packed_idx = m_base_idx + m;
const int q_packed_idx = get<0>(cS_mn(i, _0{}));
const int offset = q_packed_idx % group_size_;
const int head_idx = kv_head_idx * group_size_ + offset;
const int head_idx = (kv_head_idx * group_size_) + offset;
alibi_slopes_(i) = alibi_slops_ptr[head_idx] / sm_scale;
}
}

// rS_mn/cS_mn: ((2, MMA_M), (2, MMA_N))
template <bool OOB_MASK = true, typename FragmentS, typename IdentityS>
CUTE_HOST_DEVICE void apply(FragmentS& rS_mn,
IdentityS& cS_mn,
int m_base_idx,
int n_base_idx) const {
template <bool OOB_MASK, typename FragmentS, typename IdentityS>
CUTE_HOST_DEVICE void apply(
FragmentS& rS_mn, // ((2, MMA_M), (2, MMA_N))
IdentityS& cS_mn // ((2, MMA_M), (2, MMA_N)) => (M, N)
) const {
CUTE_UNROLL
for (int i = 0; i < size<0>(rS_mn); ++i) {
const auto alibi_slope = ALIBI ? alibi_slopes_(i) : 0.0f;
CUTE_UNROLL
for (int j = 0; j < size<1>(rS_mn); ++j) {
auto [m, n] = cS_mn(i, j);
const int q_packed_idx = m_base_idx + m;
const int kv_idx = n_base_idx + n;

const int q_idx = q_packed_idx / group_size_ + diagonal_offset_;
const auto [m, n] = cS_mn(i, j);
const int q_packed_idx = m;
const int kv_idx = n;
const int q_idx = (q_packed_idx / group_size_) + diagonal_offset_;

const bool out_of_boundary = [&]() {
if constexpr (OOB_MASK && LOCAL) {
@@ -93,4 +89,4 @@
}
};

} // namespace llm
} // namespace llm
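Reviewer note (not part of the diff): `init_alibi` and `apply` drop the `m_base_idx`/`n_base_idx` arguments because the identity tensor passed in now maps directly to global (M, N) coordinates. A small host-side sketch of why no base offset is needed; tile sizes and the block coordinate are toy values.

// Illustrative sketch only: the coordinates read from a tile of an identity
// tensor are already absolute, so the mask needs no m/n base offsets.
#include <cstdio>
#include <cute/tensor.hpp>

using namespace cute;

int main() {
  auto cMN_full = make_identity_tensor(make_shape(8, 12));  // (M, N) => (M, N)

  // Take the (BLK_M=4, BLK_N=4) tile at block coordinate (mb=1, nb=2).
  auto cS = local_tile(cMN_full, Shape<_4, _4>{}, make_coord(1, 2));

  const auto [m, n] = cS(0, 0);           // global coordinate of the tile's origin
  printf("m=%d n=%d\n", int(m), int(n));  // prints m=4 n=8
  return 0;
}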
50 changes: 26 additions & 24 deletions src/kernels/attention/sm80_collective_mha.cuh
@@ -148,9 +148,10 @@ struct Sm80CollectiveMha {
class TensorK,
class TensorV,
class TensorCKV,
class TensorCMN,
class FrgTensor,
class Softmax,
class BlockCoordMNK,
class BlockCoord,
class ResidueMNK>
CUTE_DEVICE void operator()(
const Params& params,
@@ -159,10 +160,11 @@
const TensorK& gK, // (BLK_N, HEAD_DIM, n)
const TensorV& gV, // (BLK_N, HEAD_DIM, n)
const TensorCKV& cKV, // (BLK_N, HEAD_DIM, n) => (N, K)
const TensorCMN& cMN, // (BLK_M, BLK_N, n) => (M, N)
FrgTensor& tOrO, // (MMA, MMA_M, MMA_N)
Softmax& softmax,
int tidx,
const BlockCoordMNK& block_coord_mnk,
const BlockCoord& blk_coord,
const ResidueMNK& residue_mnk, // (M, N, K)
char* smem) {
static_assert(is_rmem<FrgTensor>::value,
@@ -174,7 +176,8 @@
static constexpr int kBlockM = get<0>(TileShape{});
static constexpr int kBlockN = get<1>(TileShape{});

const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
const int q_idx = get<0>(blk_coord);
const int kv_head_idx = get<1>(blk_coord);
const int q_packed_len = get<0>(residue_mnk);
const int kv_len = get<1>(residue_mnk);

@@ -342,7 +345,7 @@ struct Sm80CollectiveMha {
auto tOrO_mn =
make_tensor(tOrO.data(), LayoutConvertor::to_mn(tOrO.layout()));

const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len;
const int diagonal = q_idx + kv_len - q_len;
// process kv in range: [kv_idx_min, kv_idx_max)
const int kv_idx_min = std::max(0, diagonal - sliding_window);
const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
@@ -385,31 +388,29 @@
constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
const int n_blocks = n_block_max - n_block_min;

// attention score accumulator, (MMA,MMA_M,MMA_N)
// attention score accumulator, (MMA, MMA_M, MMA_N)
auto tSrS = partition_fragment_C(tiled_mma, Shape<BLK_M, BLK_N>{});
// ((2, MMA_M), (2, MMA_N))
auto tSrS_mn =
make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout()));

// identity tensor for score accumulator
auto tScS =
thr_mma.partition_C(make_identity_tensor(Shape<BLK_M, BLK_N>{}));
auto tScS_mn =
make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout()));
// (MMA, MMA_M, MMA_N, n)
auto tScMN = thr_mma.partition_C(cMN);
// ((2, MMA_M), (2, MMA_N), n) => (M, N)
auto tScMN_mn =
make_tensor(tScMN.data(), LayoutConvertor::to_mn(tScMN.layout()));

constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
using Mask = Mask<kRowsPerThr, ALIBI, LOCAL>;
Mask mask(q_len, kv_len, group_size, sliding_window);
if constexpr (ALIBI) {
mask.init_alibi(tScS_mn,
m_block_idx * kBlockM,
kv_head_idx,
sm_scale,
params.alibi_slopes_ptr);
const auto tScS_mn = tScMN_mn(_, _, _0{});
mask.init_alibi(tScS_mn, kv_head_idx, sm_scale, params.alibi_slopes_ptr);
}

CUTE_NO_UNROLL
for (int i = 0; i < n_blocks; ++i) {
const int n_block_idx = n_block_max - 1 - i;
const int ni = n_block_max - 1 - i;
clear(tSrS);

// wait key, queue: [q, k] => []
@@ -418,9 +419,9 @@

// produce value, [] => [v]
if (i == 0) {
produce_value(n_block_idx);
produce_value(ni);
} else {
produce_value_no_oob(n_block_idx);
produce_value_no_oob(ni);
}
cp_async_fence();

@@ -435,18 +436,19 @@
apply_logits_soft_cap(tSrS);
}

// apply mask
// ((2, MMA_M), (2, MMA_N)) => (M, N)
const auto tScS_mn = tScMN_mn(_, _, ni);
if (i < n_oob_mask) {
mask.template apply</*OOB_MASK=*/true>(
tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
mask.template apply</*OOB_MASK=*/true>(tSrS_mn, tScS_mn);
} else {
mask.template apply</*OOB_MASK=*/false>(
tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
mask.template apply</*OOB_MASK=*/false>(tSrS_mn, tScS_mn);
}
softmax.rescale(tSrS_mn, tOrO_mn);

// produce next key: [] => [k]
if (n_block_idx > n_block_min) {
produce_key_no_oob(n_block_idx - 1);
if (ni > n_block_min) {
produce_key_no_oob(ni - 1);
}
cp_async_fence();

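Reviewer note (not part of the diff): the kernel entry that constructs the new `cMN` argument is outside this change; one plausible construction consistent with the `(BLK_M, BLK_N, n) => (M, N)` comment is sketched below. `Shape<_64, _64>`, `q_packed_len`, `kv_len`, and `m_block_idx` are stand-in values, not code from this PR.

// Illustrative sketch only: build the block-tiled identity tensor that the
// collective consumes as cMN. Shape<_64, _64> stands in for (kBlockM, kBlockN).
#include <cute/tensor.hpp>

using namespace cute;

int main() {
  const int q_packed_len = 128;   // assumed M extent (packed query length)
  const int kv_len       = 1024;  // assumed N extent
  const int m_block_idx  = 1;     // this CTA's m-block index

  // Identity over the global (M, N) coordinate space.
  auto cMN_full = make_identity_tensor(make_shape(q_packed_len, kv_len));

  // Fix the m-block, keep every n-block: (BLK_M, BLK_N, n) => (M, N).
  auto cMN = local_tile(cMN_full, Shape<_64, _64>{}, make_coord(m_block_idx, _));

  print(cMN.layout());  // shape (_64, _64, 16): one tile per kv block
  return 0;
}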
38 changes: 20 additions & 18 deletions src/kernels/attention/sm80_collective_mla.cuh
@@ -244,9 +244,10 @@ struct Sm80CollectiveMla {
class TensorCQR,
class TensorKR,
class TensorCKR,
class TensorCMN,
class FrgTensor,
class Softmax,
class BlockCoordMNK,
class BlockCoord,
class ResidueMNK,
class RopeResidueMNK>
CUTE_DEVICE void operator()(
@@ -259,10 +260,11 @@
const TensorCQR& cQ_rope, // (BLK_M, ROPE_HEAD_DIM) =>(M, K)
const TensorKR& gK_rope, // (BLK_N, HEAD_DIM, n)
const TensorCKR& cK_rope, // (BLK_N, HEAD_DIM, n) => (N, K)
const TensorCMN& cMN, // (BLK_M, BLK_N, n) => (M, N)
FrgTensor& tOrO, // (BLK_N, ROPE_HEAD_DIM, n)
Softmax& softmax,
int tidx,
const BlockCoordMNK& block_coord_mnk,
const BlockCoord& blk_coord,
const ResidueMNK& residue_mnk,
const RopeResidueMNK& rope_residue_mnk,
char* smem) {
@@ -279,12 +281,11 @@
static constexpr int kBlockN = get<1>(TileShape{});
static constexpr int kBlockK = get<2>(TileShape{});

const int m_block_idx = get<1>(block_coord_mnk);
const int q_idx = get<0>(blk_coord);
const int q_packed_len = get<0>(residue_mnk);
const int kv_len = get<1>(residue_mnk);

const auto& group_size = params.group_size;

const int q_len = q_packed_len / group_size;

// Construct shared memory tiles
@@ -586,16 +587,13 @@
cute::copy(smem_tiled_copy_S, tCrS, tCsS);
};

// output accumulator: (MMA,MMA_M,MMA_K,k)
// auto tOrO =
// partition_fragment_C(tiled_mma_pv, Shape<BLK_M, BLK_K, STEPS>{});
// clear(tOrO);
// (MMA,MMA_M,MMA_K,k) => ((2, MMA_M), (2, MMA_K), k)
auto tOrO_mn =
make_tensor(tOrO.data(), LayoutConvertor::to_mns(tOrO.layout()));
make_tensor(tOrO.data(), LayoutConvertor::to_mn(tOrO.layout()));

const int n_block_min = 0;
// process kv in range: [0, kv_idx_max)
const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len;
const int diagonal = q_idx + kv_len - q_len;
const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);

@@ -649,14 +647,17 @@
}

// ############### Mainloop ###############
// attention score accumulator, (MMA,MMA_M,MMA_N)
// attention score accumulator, (MMA, MMA_M, MMA_N)
auto tSrS = partition_fragment_C(tiled_mma_qk, Shape<BLK_M, BLK_N>{});
auto tScS =
thr_mma_qk.partition_C(make_identity_tensor(Shape<BLK_M, BLK_N>{}));
// ((2, MMA_M), (2, MMA_N))
auto tSrS_mn =
make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout()));
auto tScS_mn =
make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout()));

// (MMA, MMA_M, MMA_N, n) => (M, N)
auto tScMN = thr_mma_qk.partition_C(cMN);
// ((2, MMA_M), (2, MMA_N), n) => (M, N)
auto tScMN_mn =
make_tensor(tScMN.data(), LayoutConvertor::to_mn(tScMN.layout()));

constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
using Mask = Mask<kRowsPerThr, /*ALIBI=*/false, /*LOCAL=*/false>;
@@ -688,11 +689,12 @@
}

// apply mask
// ((2, MMA_M), (2, MMA_N)) => (M, N)
const auto tScS_mn = tScMN_mn(_, _, ni);
if (i < n_oob_mask) {
mask.apply(tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN);
mask.apply</*OOB_MASK=*/true>(tSrS_mn, tScS_mn);
} else {
mask.apply</*OOB_MASK=*/false>(
tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN);
mask.apply</*OOB_MASK=*/false>(tSrS_mn, tScS_mn);
}

softmax.rescale(tSrS_mn, tOrO_mn, reduce_rowmax);