Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 0 additions & 157 deletions src/kernels/attention/mha_tile.h

This file was deleted.

147 changes: 0 additions & 147 deletions src/kernels/attention/mla_tile.h

This file was deleted.

31 changes: 13 additions & 18 deletions src/kernels/attention/sm80_collective_epilogue.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,19 @@ struct Sm80CollectiveEpilogue {
template <class FrgTensor,
class TiledMma,
class TensorO,
class BlockCoordMNK,
class ProblemShapeMNK>
CUTE_DEVICE void operator()(const Params& /*params*/,
const FrgTensor& tOrAccO, // (MMA, MMA_M, MMA_N)
TiledMma tiled_mma,
TensorO& gO, // (BLK_M, HEAD_DIM)
int tidx,
const BlockCoordMNK& block_coord_mnk,
const ProblemShapeMNK& problem_shape_mnk,
char* smem) {
class TensorCO,
class ResidueMNK>
CUTE_DEVICE void operator()(
const Params& /*params*/,
const FrgTensor& tOrAccO, // (MMA, MMA_M, MMA_N)
TiledMma tiled_mma,
TensorO& gO, // (BLK_M, HEAD_DIM)
const TensorCO& cO, // (BLK_M, HEAD_DIM) => (M, K)
int tidx,
const ResidueMNK& residue_mnk,
char* smem) {
static constexpr int kBlockM = get<0>(TileShape{});

const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;

// Smem
auto& ss = *reinterpret_cast<SharedStorage*>(smem);
// (BLK_M, HEAD_DIM)
Expand All @@ -106,9 +104,6 @@ struct Sm80CollectiveEpilogue {
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);

// (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
auto cO = make_identity_tensor(Shape<BLK_M, HEAD_DIM>{});

auto tOsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K)
auto tOgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K)
// (CPY,CPY_M,CPY_K) -> (blk_m, head_dim)
Expand All @@ -117,9 +112,9 @@ struct Sm80CollectiveEpilogue {
// wait for smem copy done before gmem copy
__syncthreads();

auto max_coord = make_coord(q_packed_len - m_block_idx * kBlockM, head_dim);
const auto residue_mk = select<0, 2>(residue_mnk);
safe_copy</*EVEN_M=*/false, EVEN_K, /*ZFILL_M=*/false, /*ZFILL_K=*/false>(
gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
gmem_tiled_copy_O, tOsO, tOgO, tOcO, residue_mk);
}
};
} // namespace llm
Loading