Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ struct Sm120CollectiveEpilogue {
using Params = Arguments;

// Convert host side arguments to device side params
static Params to_underlying_arguments(Arguments const& args) { return args; }
// Convert host-side arguments to device-side params.
// Params is an alias of Arguments here, so this is a pass-through; the
// problem_shape and workspace parameters exist only to match the collective
// builder interface used by the other collectives.
template <class ProblemShape>
static Params to_underlying_arguments(const ProblemShape& /*problem_shape*/,
                                      const Arguments& args,
                                      void* /*workspace*/) {
  return args;
}

template <class Block, class FrgTensor, class TiledMma>
CUTE_DEVICE void operator()(const Params& /*params*/,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ struct Sm120CollectiveFMhaWs {

// Host side arguments
// Host-side arguments for the FMHA mainloop collective.
struct Arguments {
  // sliding window attention window size
  // NOTE(review): sentinel for "disabled" not visible here — presumably <= 0; confirm against callers
  int sliding_window;
  // logits soft-capping value; 0 disables soft-capping (see to_underlying_arguments)
  float logits_soft_cap;
  // softmax scale applied to Q*K^T logits
  float sm_scale;
  // alibi slopes pointer, [n_heads]; may be null when alibi is disabled — TODO confirm
  const float* alibi_slopes_ptr;
};

// Device side params
struct Params {
// sliding window attention
int sliding_window;
// softcap
Expand All @@ -132,20 +144,37 @@ struct Sm120CollectiveFMhaWs {
float sm_scale;
// softmax scale in log2
float sm_scale_log2;
// group size
const FastDivmod& group_size;

// alibi slopes pointer
const float* alibi_slopes_ptr;
};

// Device side params
using Params = Arguments;

// Convert host side arguments to device side params
static Params to_underlying_arguments(Arguments const& args) {
// no convertion needed.
return args;
// Convert host-side arguments to device-side params.
// When soft-capping is enabled, the soft-cap is folded into the two scale
// values so the kernel applies a single Tanh + scale instead of a separate
// soft-cap pass. Also precomputes the softmax scale in the log2 domain
// (sm_scale * log2(e)).
template <class ProblemShape>
static Params to_underlying_arguments(const ProblemShape& /*problem_shape*/,
                                      const Arguments& args,
                                      void* /*workspace*/) {
  float sm_scale = args.sm_scale;
  float logits_soft_cap = args.logits_soft_cap;
  // a soft-cap of exactly 0 means soft-capping is disabled
  if (logits_soft_cap != 0.0) {
    // Softmax(x * sm_scale) + apply_logits_soft_cap
    // => Softmax(Tanh(x * sm_scale / soft_cap) * soft_cap)
    // => Softmax(S' * sm_scale') where
    //      S' = Tanh(x * sm_scale / soft_cap)
    //         = Tanh(x * soft_cap')
    //      soft_cap' = sm_scale / soft_cap
    //      sm_scale' = soft_cap
    const auto sm_scale_hat = logits_soft_cap;
    logits_soft_cap = sm_scale / logits_soft_cap;
    sm_scale = sm_scale_hat;
  }
  const float sm_scale_log2 = static_cast<float>(sm_scale * M_LOG2E);
  return {
      .sliding_window = args.sliding_window,
      .logits_soft_cap = logits_soft_cap,
      .sm_scale = sm_scale,
      .sm_scale_log2 = sm_scale_log2,
      .alibi_slopes_ptr = args.alibi_slopes_ptr,
  };
}

// load Q/K/V from gmem to smem
Expand Down Expand Up @@ -193,6 +222,7 @@ struct Sm120CollectiveFMhaWs {
const auto q_packed_len = block.get_packed_len();
const auto q_len = block.get_q_len();
const auto kv_len = block.get_kv_len();
const auto& group_size = block.get_group_size();

// Construct smem tensors
// (BLK_M, BLK_K), k-major
Expand Down Expand Up @@ -335,7 +365,7 @@ struct Sm120CollectiveFMhaWs {
// Create softmax and mask
OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
Mask<kRowsPerThr, kAlibi, kLocal> mask(
q_len, kv_len, params.group_size, params.sliding_window);
q_len, kv_len, group_size, params.sliding_window);
if constexpr (kAlibi) {
const auto tScS_mn = tScMN_mn(_, _, _0{});
mask.init_alibi(
Expand Down
90 changes: 85 additions & 5 deletions src/kernels/attention/common/fmha_block.h
Original file line number Diff line number Diff line change
@@ -1,21 +1,93 @@
#pragma once

#include <cute/config.hpp>
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "cute/config.hpp"
#include "common/fast_math.h"
#include "gather_tensor.h"

namespace llm {

using namespace cute;

// AttentionTile specialization for AttentionParams
template <typename Params,
typename TileShape, // (BLK_M, BLK_N, BLK_K)
template <typename TileShape, // (BLK_M, BLK_N, BLK_K)
typename Element, // Element type
bool kLocal>
struct FmhaBlock {
// (B, Q, H, D)
using StrideQ = Stride<int64_t, int64_t, int64_t, _1>;
using StrideO = StrideQ;
// (B, K, KH, D)
using StrideK = Stride<int64_t, int64_t, int64_t, _1>;
using StrideV = StrideK;

// Host side parameters

// Host-side arguments: raw tensor pointers and strides for Q/K/V/O.
struct Arguments {
  // Q/K/V input pointers; element type is the block's Element template param
  const void* __restrict__ q_ptr;
  const void* __restrict__ k_ptr;
  const void* __restrict__ v_ptr;
  // output pointer
  void* __restrict__ o_ptr;

  // strides: Q/O are (B, Q, H, D), K/V are (B, K, KH, D), innermost dim contiguous
  StrideQ q_stride;
  StrideK k_stride;
  StrideV v_stride;
  StrideO o_stride;

  int sliding_window = -1;  // -1 means no sliding window
};

// Device side parameters
// Device-side parameters: the host arguments plus lengths and a precomputed
// FastDivmod derived from the problem shape (see to_underlying_arguments).
struct Params {
  // Q/K/V input pointers
  const void* __restrict__ q_ptr;
  const void* __restrict__ k_ptr;
  const void* __restrict__ v_ptr;
  // output pointer
  void* __restrict__ o_ptr;

  // tensor strides, copied from Arguments
  StrideQ q_stride;
  StrideK k_stride;
  StrideV v_stride;
  StrideO o_stride;

  // sliding window size, copied from Arguments (-1 means disabled)
  int sliding_window;

  // Parameters derived from problem shape
  int q_len;
  int kv_len;
  int head_dim;
  // query heads per kv head, wrapped in FastDivmod for fast div/mod on device
  FastDivmod group_size;
};

// Convert host-side arguments to device-side params, extracting the sequence
// lengths, head dim and group size from the hierarchical problem shape.
// workspace is currently unused.
template <class ProblemShape>
static Params to_underlying_arguments(const ProblemShape& problem_shape,
                                      const Arguments& args,
                                      void* workspace = nullptr) {
  // ProblemShape: (Q, K, D, ((KH, G), B))
  const int q_len = size<0>(problem_shape);
  const int kv_len = size<1>(problem_shape);
  const int head_dim = size<2>(problem_shape);
  const int group_size = size<3, 0, 1>(problem_shape);

  // TODO: construct tma_load for k/v tensors
  return {
      .q_ptr = args.q_ptr,
      .k_ptr = args.k_ptr,
      .v_ptr = args.v_ptr,
      .o_ptr = args.o_ptr,
      .q_stride = args.q_stride,
      .k_stride = args.k_stride,
      .v_stride = args.v_stride,
      .o_stride = args.o_stride,
      .sliding_window = args.sliding_window,
      .q_len = q_len,
      .kv_len = kv_len,
      .head_dim = head_dim,
      // precompute divisor magic once on host for fast device-side div/mod
      .group_size = FastDivmod(group_size),
  };
}

static constexpr int kBlockM = get<0>(TileShape{});
static constexpr int kBlockN = get<1>(TileShape{});
static constexpr int kBlockK = get<2>(TileShape{});
Expand All @@ -26,13 +98,15 @@ struct FmhaBlock {

// hold a reference to the parameters and block coordination
const Params& params_;
// TODO: (m_block_idx, (kv_head_idx, batch_idx))
// (batch_idx, m_block_idx, kv_head_idx)
const tuple<int, int, int>& blk_coord_;

// derived parameters to avoid recomputation
int m_block_base_;
int packed_len_;

// Constructor
CUTE_HOST_DEVICE FmhaBlock(const Params& params,
const tuple<int, int, int>& blk_coord)
: params_(params), blk_coord_(blk_coord) {
Expand All @@ -50,20 +124,26 @@ struct FmhaBlock {
// returns the packed query length (q_len * group_size), cached in the ctor
CUTE_HOST_DEVICE int get_packed_len() const { return packed_len_; }

// returns actual query length
CUTE_HOST_DEVICE int get_q_len() const { return params_.q_len; }

// returns actual kv length
CUTE_HOST_DEVICE int get_kv_len() const { return params_.kv_len; }

// returns head_dim
CUTE_HOST_DEVICE int get_head_dim() const { return params_.head_dim; }

// returns group size (query heads per kv head) as a FastDivmod reference
CUTE_HOST_DEVICE const FastDivmod& get_group_size() const {
  return params_.group_size;
}

// returns residue mnk: (packed_len, kv_len, head_dim), used for predication
CUTE_HOST_DEVICE auto get_residue_mnk() const {
  return make_tuple(packed_len_, params_.kv_len, params_.head_dim);
}

// returns the block coordinate (batch_idx, m_block_idx, kv_head_idx)
CUTE_HOST_DEVICE const auto& get_block_coord() const { return blk_coord_; }

Expand Down
34 changes: 23 additions & 11 deletions src/kernels/attention/common/tile_scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,36 @@
#include <cute/tensor.hpp>

namespace llm {

using namespace cute;
class SingleTileScheduler {
public:
// Host side kernel arguments
struct Arguments {
// Device side kernel arguments
// Device side kernel arguments
struct Params {
  // number of batches; maps to grid.x (see get_grid_shape)
  int batch_size = 0;
  // number of M-dim tile blocks (ceil(packed q len / BLK_M)); maps to grid.y
  int m_blocks = 0;
  // number of kv heads; maps to grid.z
  int n_kv_heads = 0;
};
static dim3 get_grid_shape(Arguments const& args) {
return {(uint32_t)args.batch_size,
(uint32_t)args.m_blocks,
(uint32_t)args.n_kv_heads};

// Computes the kernel launch grid: one CTA per (batch, m_block, kv_head),
// mapped to (grid.x, grid.y, grid.z) respectively.
static dim3 get_grid_shape(Params const& params) {
  // named casts instead of C-style casts; values are non-negative counts
  return {static_cast<uint32_t>(params.batch_size),
          static_cast<uint32_t>(params.m_blocks),
          static_cast<uint32_t>(params.n_kv_heads)};
}

// Device side kernel params
using Params = Arguments;
static Params to_underlying_arguments(const Arguments& args) { return args; }
// Derive the scheduling params (grid extents) from the problem shape and the
// CTA tile shape. Queries are packed across the head group, so the M extent
// is ceil(max_q_len * group_size / BLK_M).
template <class ProblemShape, class TileShape>
static Params to_underlying_arguments(const ProblemShape& problem_shape,
                                      const TileShape& tile_shape) {
  // problem_shape: (Q, K, D, ((KH, G), B))
  const int max_q_len = size<0>(problem_shape);
  const int n_kv_heads = size<3, 0, 0>(problem_shape);
  const int group_size = size<3, 0, 1>(problem_shape);
  const int batch_size = size<3, 1>(problem_shape);

  // pack the query-head group into the M dimension
  const int max_q_packed_len = max_q_len * group_size;
  const int m_blocks = ceil_div(max_q_packed_len, size<0>(tile_shape));

  return {batch_size, m_blocks, n_kv_heads};
}

// End Iterator tag
class EndIterator {};
Expand All @@ -34,7 +46,7 @@ class SingleTileScheduler {
Iterator() = default;

CUTE_DEVICE
cute::tuple<int, int, int> operator*() const {
tuple<int, int, int> operator*() const {
  // Single-tile schedule: the block coordinate comes straight from the CUDA
  // grid as (batch_idx, m_block_idx, kv_head_idx) — see get_grid_shape.
  return {blockIdx.x, blockIdx.y, blockIdx.z};
}
Expand Down
Loading