2 changes: 1 addition & 1 deletion src/kernels/attention/sm80_collective_epilogue.cuh
@@ -83,7 +83,7 @@ struct Sm80CollectiveEpilogue {
char* smem) {
static constexpr int kBlockM = get<0>(TileShape{});

const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;

// Smem
2 changes: 1 addition & 1 deletion src/kernels/attention/sm80_collective_mha.cuh
@@ -169,7 +169,7 @@ struct Sm80CollectiveMha {
static constexpr int kBlockM = get<0>(TileShape{});
static constexpr int kBlockN = get<1>(TileShape{});

const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;

const int sliding_window = LOCAL ? params.sliding_window : kv_len;
143 changes: 78 additions & 65 deletions src/kernels/attention/sm80_kernel_mha.cuh
@@ -13,11 +13,14 @@ namespace llm {

using namespace cute;

template <class CollectiveMainloop_, class CollectiveEpilogue_>
template <class CollectiveMainloop_,
class CollectiveEpilogue_,
class TileScheduler_>
class Sm80KernelMha {
public:
using CollectiveMainloop = CollectiveMainloop_;
using CollectiveEpilogue = CollectiveEpilogue_;
using TileScheduler = TileScheduler_;

using TiledMma = typename CollectiveMainloop::TiledMma;

@@ -39,45 +42,22 @@ class Sm80KernelMha {
// Kernel params
using MainloopParams = typename CollectiveMainloop::Params;
using EpilogueParams = typename CollectiveEpilogue::Params;
using TileSchedulerParams = typename TileScheduler::Params;

// returns grid and block shape for kernel launch
using TileSchedulerArgs = typename TileScheduler::Arguments;
static dim3 get_grid_shape(TileSchedulerArgs const& args) {
return TileScheduler::get_grid_shape(args);
}
static dim3 get_block_shape() { return kMmaThreads; }

template <class Params>
CUTE_DEVICE void operator()(const Params& params, char* smem) {
CUTE_DEVICE void operator()(const Params& params,
const TileSchedulerParams& scheduler_params,
char* smem) {
CollectiveMainloop mha;
CollectiveEpilogue epilogue;

const auto tidx = threadIdx.x;

// block coord
const int m_block_idx = blockIdx.x;
const int batch_idx = blockIdx.y;
const int kv_head_idx = blockIdx.z;
auto block_coord_mnk = make_coord(m_block_idx, batch_idx, kv_head_idx);

// (q_packed_len, HEAD_DIM)
MHATile<Params> tile(params, batch_idx, kv_head_idx);
auto [Q, O] = tile.template get_qo_tile<Element>();
// (kv_len, HEAD_DIM)
auto [K, V] = tile.template get_kv_tile<Element>();

// problem shape
const int q_packed_len = size<0>(Q);
const int kv_len = size<0>(K);
const int head_dim = params.head_dim;
auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);

if (m_block_idx * kBlockM >= q_packed_len) {
// m out of bound, return
return;
}

// (BLK_M, HEAD_DIM)
Tensor gQ =
local_tile(Q, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
Tensor gO =
local_tile(O, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
// (BLK_N, HEAD_DIM, n)
Tensor gK = local_tile(K, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
Tensor gV = local_tile(V, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
TileScheduler scheduler(scheduler_params);

// construct params
MainloopParams mainloop_params{params.sliding_window,
@@ -88,35 +68,68 @@ class Sm80KernelMha {
params.group_size};
EpilogueParams epilogue_params;

TiledMma tiled_mma;
// accumulator: (MMA, MMA_M, MMA_N)
auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, HEAD_DIM>{});
clear(tOrAccO);

constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);

// mainloop
mha(mainloop_params,
gQ,
gK,
gV,
tOrAccO,
softmax,
tidx,
block_coord_mnk,
problem_shape_mnk,
smem);

// epilogue
epilogue(epilogue_params,
tOrAccO,
tiled_mma,
gO,
tidx,
block_coord_mnk,
problem_shape_mnk,
smem);
// process each block
for (const auto block_coord : scheduler) {
// block coord: (batch_idx, m_block_idx, kv_head_idx)
const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
const auto tidx = threadIdx.x;

// (q_packed_len, HEAD_DIM)
MHATile<Params> tile(params, batch_idx, kv_head_idx);
auto [Q, O] = tile.template get_qo_tile<Element>();
// (kv_len, HEAD_DIM)
auto [K, V] = tile.template get_kv_tile<Element>();

// problem shape
const int q_packed_len = size<0>(Q);
const int kv_len = size<0>(K);
const int head_dim = params.head_dim;
auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);

if (m_block_idx * kBlockM >= q_packed_len) {
// m out of bound, skip this block
continue;
}

// (BLK_M, HEAD_DIM)
Tensor gQ = local_tile(
Q, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
Tensor gO = local_tile(
O, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
// (BLK_N, HEAD_DIM, n)
Tensor gK = local_tile(K, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
Tensor gV = local_tile(V, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));

TiledMma tiled_mma;
// accumulator: (MMA, MMA_M, MMA_N)
auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, HEAD_DIM>{});
clear(tOrAccO);

constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);

// mainloop
mha(mainloop_params,
gQ,
gK,
gV,
tOrAccO,
softmax,
tidx,
block_coord,
problem_shape_mnk,
smem);

// epilogue
epilogue(epilogue_params,
tOrAccO,
tiled_mma,
gO,
tidx,
block_coord,
problem_shape_mnk,
smem);
}
}
};

26 changes: 19 additions & 7 deletions src/kernels/attention/sm80_mha_launch.cuh
@@ -9,17 +9,20 @@
#include "sm80_collective_epilogue.cuh"
#include "sm80_collective_mha.cuh"
#include "sm80_kernel_mha.cuh"
#include "tile_scheduler.cuh"

namespace llm {

namespace detail {
/// Generic kernel template.
template <typename Operator, typename Params>
__global__ __launch_bounds__(Operator::kMmaThreads) void device_kernel(
__grid_constant__ const Params params) {
__grid_constant__ const Params params,
__grid_constant__ const typename Operator::TileSchedulerParams
scheduler_params) {
extern __shared__ char smem[];
Operator op;
op(params, smem);
op(params, scheduler_params, smem);
}
} // namespace detail

@@ -61,7 +64,17 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
using CollectiveEpilogue =
Sm80CollectiveEpilogue<TileShape, Dtype, HEAD_DIM, EVEN_K>;

using AttnKernel = Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue>;
// TODO: support persistent kernels
using TileScheduler = SingleTileScheduler;

const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M);
typename TileScheduler::Arguments scheduler_args{
batch_size, m_blocks, n_kv_heads};
auto scheduler_params =
TileScheduler::to_underlying_arguments(scheduler_args);

using AttnKernel =
Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue, TileScheduler>;

auto mha_kernel = detail::device_kernel<AttnKernel, Params>;

@@ -71,11 +84,10 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}

// TODO: support persistent kernels
dim3 grid(cute::ceil_div(max_q_packed_len, BLK_M), batch_size, n_kv_heads);
dim3 block = AttnKernel::kMmaThreads;
const dim3 grid = AttnKernel::get_grid_shape(scheduler_args);
const dim3 block = AttnKernel::get_block_shape();

mha_kernel<<<grid, block, smem_size, stream>>>(params);
mha_kernel<<<grid, block, smem_size, stream>>>(params, scheduler_params);
// TODO: check launch status
}

63 changes: 63 additions & 0 deletions src/kernels/attention/tile_scheduler.cuh
@@ -0,0 +1,63 @@
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <cute/layout.hpp>
#include <cute/tensor.hpp>

namespace llm {

class SingleTileScheduler {
public:
// Host side kernel arguments
struct Arguments {
int batch_size = 0;
int m_blocks = 0;
int n_kv_heads = 0;
};
static dim3 get_grid_shape(Arguments const& args) {
return {(uint32_t)args.batch_size,
(uint32_t)args.m_blocks,
(uint32_t)args.n_kv_heads};
}

// Device side kernel params
using Params = Arguments;
static Params to_underlying_arguments(const Arguments& args) { return args; }

// End Iterator tag
class EndIterator {};
class Iterator {
public:
CUTE_DEVICE
Iterator() = default;

CUTE_DEVICE
dim3 operator*() const { return blockIdx; }

CUTE_DEVICE
Iterator& operator++() {
valid_ = false;
return *this;
}

// compare against end iterator
CUTE_DEVICE
bool operator!=(const EndIterator&) const { return valid_; }

private:
bool valid_ = true;
};

CUTE_DEVICE
SingleTileScheduler(const Params& params) {}

CUTE_DEVICE
Iterator begin() const { return {}; }

CUTE_DEVICE
EndIterator end() const { return {}; }
};

} // namespace llm
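
Note on the new scheduler abstraction: Sm80KernelMha depends only on the scheduler's Arguments/Params types, get_grid_shape(), and the begin()/end()/operator* range-for protocol, and the launch path leaves persistent kernels as a TODO. Purely to illustrate that extension point, a hypothetical persistent scheduler with the same surface might look like the sketch below; the name PersistentTileScheduler, the fixed grid cap, and the linear-index decode order are assumptions of the sketch, not part of this PR.

#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <cute/layout.hpp>
#include <cute/tensor.hpp>

namespace llm {

// Hypothetical sketch only: a persistent scheduler exposing the same
// Arguments/Params/iterator surface that Sm80KernelMha consumes. Each
// resident block walks linear tile indices with a grid-sized stride and
// decodes them into the same (batch_idx, m_block_idx, kv_head_idx) coord
// that SingleTileScheduler reports via blockIdx.
class PersistentTileScheduler {
 public:
  // Host side kernel arguments (same fields as SingleTileScheduler)
  struct Arguments {
    int batch_size = 0;
    int m_blocks = 0;
    int n_kv_heads = 0;
  };

  static dim3 get_grid_shape(Arguments const& args) {
    // a real implementation would size the grid from an occupancy query;
    // a fixed cap keeps the sketch self-contained
    const int total = args.batch_size * args.m_blocks * args.n_kv_heads;
    return {(uint32_t)(total < 128 ? total : 128), 1u, 1u};
  }

  // Device side kernel params
  struct Params {
    int total_tiles = 0;
    int m_blocks = 0;
    int n_kv_heads = 0;
  };
  static Params to_underlying_arguments(const Arguments& args) {
    return {args.batch_size * args.m_blocks * args.n_kv_heads,
            args.m_blocks,
            args.n_kv_heads};
  }

  // End Iterator tag
  class EndIterator {};
  class Iterator {
   public:
    CUTE_DEVICE
    Iterator(int tile_idx, const Params& params)
        : tile_idx_(tile_idx), params_(params) {}

    // decode the linear tile index into (batch_idx, m_block_idx, kv_head_idx)
    CUTE_DEVICE
    dim3 operator*() const {
      const int kv_head_idx = tile_idx_ % params_.n_kv_heads;
      const int rest = tile_idx_ / params_.n_kv_heads;
      return {(uint32_t)(rest / params_.m_blocks),
              (uint32_t)(rest % params_.m_blocks),
              (uint32_t)kv_head_idx};
    }

    CUTE_DEVICE
    Iterator& operator++() {
      // grid-stride advance to the next tile owned by this block
      tile_idx_ += gridDim.x;
      return *this;
    }

    // compare against end iterator
    CUTE_DEVICE
    bool operator!=(const EndIterator&) const {
      return tile_idx_ < params_.total_tiles;
    }

   private:
    int tile_idx_ = 0;
    Params params_;
  };

  CUTE_DEVICE
  PersistentTileScheduler(const Params& params) : params_(params) {}

  CUTE_DEVICE
  Iterator begin() const { return {(int)blockIdx.x, params_}; }

  CUTE_DEVICE
  EndIterator end() const { return {}; }

 private:
  Params params_;
};

}  // namespace llm

Under this contract, swapping the TileScheduler alias in sm80_mha_launch.cuh would be the only change needed on the kernel and launch side.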