65 changes: 65 additions & 0 deletions src/kernels/gemm/fast_cast.cuh
@@ -0,0 +1,65 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cute/numeric/numeric_types.hpp>
#include <cute/tensor.hpp>

namespace llm {

namespace detail {
using namespace cute;
template <typename SRC_TYPE, typename DST_TYPE>
struct type_cast {
static_assert(dependent_false<SRC_TYPE>, "not implemented");
};

// specialization for float -> half
template <>
struct type_cast<float, cute::half_t> {
template <typename FragmentS, typename FragmentD>
CUTE_DEVICE static void cast(const FragmentS& src, FragmentD& dst) {
auto src2 = recast<float2>(src);
auto dst2 = recast<half2>(dst);

CUTE_UNROLL
for (int i = 0; i < size(src2); ++i) {
dst2(i) = __float22half2_rn(src2(i));
}
}
};

// specialization for float -> bfloat16
template <>
struct type_cast<float, cute::bfloat16_t> {
template <typename FragmentS, typename FragmentD>
CUTE_DEVICE static void cast(const FragmentS& src, FragmentD& dst) {
auto src2 = recast<float2>(src);
auto dst2 = recast<nv_bfloat162>(dst);

CUTE_UNROLL
for (int i = 0; i < size(src2); ++i) {
dst2(i) = __float22bfloat162_rn(src2(i));
}
}
};
// TODO: add other specializations

} // namespace detail

// dispatch to the right type_cast specialization
// functionality: dst = (DST_TYPE)src
template <typename FragmentS, typename FragmentD>
CUTE_DEVICE void fast_cast(const FragmentS& src, FragmentD& dst) {
CUTE_STATIC_ASSERT_V((cute::size(src) == cute::size(dst)), "size mismatch");

using TypeSrc = typename FragmentS::value_type;
using TypeDst = typename FragmentD::value_type;

// dispatch to type_cast
detail::type_cast<TypeSrc, TypeDst>::cast(src, dst);
}

} // namespace llm
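The TODO above leaves the remaining conversions open. As a rough illustration of how one might be filled in, here is a sketch (not part of this PR) of a half -> float specialization that would sit next to the existing ones inside llm::detail; it mirrors the same recast-to-packed-pairs pattern and assumes the __half22float2 intrinsic from cuda_fp16.h is the intended conversion path.

// Sketch only, not in this PR: half -> float, vectorized over half2/float2
// pairs like the float -> half specialization above.
template <>
struct type_cast<cute::half_t, float> {
  template <typename FragmentS, typename FragmentD>
  CUTE_DEVICE static void cast(const FragmentS& src, FragmentD& dst) {
    auto src2 = recast<half2>(src);
    auto dst2 = recast<float2>(dst);

    CUTE_UNROLL
    for (int i = 0; i < size(src2); ++i) {
      dst2(i) = __half22float2(src2(i));
    }
  }
};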
7 changes: 4 additions & 3 deletions src/kernels/gemm/gather_tensor.hpp
@@ -2,9 +2,10 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/common/gather_tensor.hpp
#pragma once

#include "cute/layout.hpp"
#include "cute/layout_composed.hpp"
#include "cute/tensor.hpp"
#include <cute/layout.hpp>
#include <cute/layout_composed.hpp>
#include <cute/tensor.hpp>

namespace llm {

using namespace cute;
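The kernel in the next file leans on gather_tensor.hpp (adapted from the CUTLASS example it links to above) to keep the K mode contiguous while routing the row mode through an index function. A host-side sketch of the address math this encodes, using illustrative names (gather_row, gathered_addr, ld) that do not exist in this repo:

#include <cstdint>

// Sketch only: with A = make_gather_tensor(a_ptr, make_shape(M, K),
// make_stride(ld, _1{}), f), an access A(i, k) resolves to
// a_ptr[f(i) * ld + k]; in the kernel below, f is idx_to_t_idx.
inline int gather_row(const int* sorted_token_idxes, int i, int topk) {
  // flattened (token, expert-slot) index -> source token row
  return sorted_token_idxes[i] / topk;
}

inline const float* gathered_addr(const float* a_ptr,
                                  const int* sorted_token_idxes,
                                  int i, int k, int64_t ld, int topk) {
  return a_ptr + gather_row(sorted_token_idxes, i, topk) * ld + k;
}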
136 changes: 101 additions & 35 deletions src/kernels/gemm/grouped_gemm_kernel_sm80.cuh
@@ -6,7 +6,9 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "fast_cast.cuh"
#include "gather_tensor.hpp"
#include "safe_copy.hpp"

namespace llm {
using namespace cute;
@@ -151,13 +153,12 @@ struct GEMMParams {
int n = 0;
int k = 0;
int topk = 0;
int n_experts = 0;

int m_blocks = 0;
int n_blocks = 0;
};

template <typename Traits, typename Params>
template <bool EVEN_N, bool EVEN_K, typename Traits, typename Params>
__global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
__grid_constant__ const Params params) {
// Traits
@@ -169,7 +170,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
using _BLK_N = Int<kBlockN>;
using _BLK_K = Int<kBlockK>;

using DTYPE = typename Traits::DType;
using DType = typename Traits::DType;
using TiledMma = typename Traits::TiledMma;

using SmemLayoutA = typename Traits::SmemLayoutA;
@@ -187,52 +188,67 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
const auto M = kBlockM * gridDim.x;
const auto N = params.n;
const auto K = params.k;

const auto topk = params.topk;
const auto n_experts = params.n_experts;

// each thread block takes care of one block: (BLK_M, BLK_N)
const auto m_block_idx = blockIdx.x;
const auto n_block_idx = blockIdx.y;
const auto tidx = threadIdx.x;

const int expert_id = params.expert_ids_ptr[m_block_idx];
const int n_flatten_tokens = params.m * topk;

// ProblemShape
const int* sorted_token_idxes = params.sorted_token_idxes_ptr;
auto idx_to_t_idx = [sorted_token_idxes, topk](int idx) {
return sorted_token_idxes[idx] / topk;
};
// A: (M, K), k-major
auto A = make_gather_tensor(make_gmem_ptr((const DTYPE*)params.a_ptr),
auto A = make_gather_tensor(make_gmem_ptr((const DType*)params.a_ptr),
make_shape(M, K),
make_stride(get<0>(params.a_stride), _1{}),
idx_to_t_idx);

// B: (N, K), k-major
const auto b_offset = expert_id * get<0>(params.b_stride);
auto B = make_tensor(make_gmem_ptr((const DTYPE*)params.b_ptr + b_offset),
auto B = make_tensor(make_gmem_ptr((const DType*)params.b_ptr + b_offset),
make_shape(N, K),
make_stride(get<1>(params.b_stride), _1{}));

// C: (M, N), n-major
auto idx_to_f_idx = [sorted_token_idxes](int idx) {
return sorted_token_idxes[idx];
};
auto C = make_gather_tensor(make_gmem_ptr((DTYPE*)params.c_ptr),
auto C = make_gather_tensor(make_gmem_ptr((DType*)params.c_ptr),
make_shape(M, N),
make_stride(get<0>(params.c_stride), _1{}),
idx_to_f_idx);

auto max_coord_mk = make_coord(M, K);
auto max_coord_nk = make_coord(N, K);
auto max_coord_mn = make_coord(M, N);

// (M, K) => (BLK_M, BLK_K, k)
Tensor gA =
local_tile(A, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
// (BLK_M, BLK_K, k) => (M, K)
Tensor cA = local_tile(make_identity_tensor(make_shape(M, K)),
Shape<_BLK_M, _BLK_K>{},
make_coord(m_block_idx, _));
// (N, K) => (BLK_N, BLK_K, k)
Tensor gB =
local_tile(B, Shape<_BLK_N, _BLK_K>{}, make_coord(n_block_idx, _));
// (BLK_N, BLK_K, k) => (N, K)
Tensor cB = local_tile(make_identity_tensor(make_shape(N, K)),
Shape<_BLK_N, _BLK_K>{},
make_coord(n_block_idx, _));
// (M, N) => (BLK_M, BLK_N)
Tensor gC = local_tile(
C, Shape<_BLK_M, _BLK_N>{}, make_coord(m_block_idx, n_block_idx));
// (BLK_M, BLK_N) => (M, N)
Tensor cC = local_tile(make_identity_tensor(make_shape(M, N)),
Shape<_BLK_M, _BLK_N>{},
make_coord(m_block_idx, n_block_idx));

// Smem
extern __shared__ char smem[];
@@ -249,16 +265,63 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
GmemTiledCopy gmem_tiled_copy;
auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx);

// (BLK_M, BLK_K, k) => (COPY, CP_M, CP_K, k)
// (BLK_M, BLK_K, k) => (CPY, CPY_M, CPY_K, k)
auto tAgA = gmem_thr_copy.partition_S(gA);
// (BLK_M, BLK_K, PIPE) => (COPY, CP_M, CP_K, PIPE)
// (CPY, CPY_M, CPY_K, k) => (M, K)
auto tAcA = gmem_thr_copy.partition_S(cA);
// (BLK_M, BLK_K, PIPE) => (CPY, CPY_M, CPY_K, PIPE)
auto tAsA = gmem_thr_copy.partition_D(sA);

// (BLK_N, BLK_K, k) => (COPY, CP_N, CP_K, k)
// (CPY_M) => (M, K)
auto tAcA_m = tAcA(_0{}, _, _0{}, _0{});
auto tApA = make_tensor<bool>(make_shape(size(tAcA_m)));
CUTE_UNROLL
for (int i = 0; i < size(tApA); ++i) {
const auto f_idx = sorted_token_idxes[get<0>(tAcA_m(i))];
tApA(i) = f_idx < n_flatten_tokens;
}

// (BLK_N, BLK_K, k) => (CPY, CPY_N, CPY_K, k)
auto tBgB = gmem_thr_copy.partition_S(gB);
// (BLK_N, BLK_K, PIPE) => (COPY, CP_N, CP_K, PIPE)
// (CPY, CPY_N, CPY_K, k) => (N, K)
auto tBcB = gmem_thr_copy.partition_S(cB);
// (BLK_N, BLK_K, PIPE) => (CPY, CPY_N, CPY_K, PIPE)
auto tBsB = gmem_thr_copy.partition_D(sB);

auto produce_ab = [&](int k_tile, int k_pipe) {
safe_copy_m<EVEN_K, /*ZFILL_M=*/true, /*ZFILL_K=*/true>(
gmem_tiled_copy,
tAgA(_, _, _, k_tile),
tAsA(_, _, _, k_pipe),
tApA,
tAcA(_, _, _, k_tile),
max_coord_mk);

safe_copy_n<EVEN_N, EVEN_K, /*ZFILL_N=*/true, /*ZFILL_K=*/true>(
gmem_tiled_copy,
tBgB(_, _, _, k_tile),
tBsB(_, _, _, k_pipe),
tBcB(_, _, _, k_tile),
max_coord_nk);
};

auto produce_ab_no_oob = [&](int k_tile, int k_pipe) {
safe_copy_m<EVEN_K, /*ZFILL_M=*/false, /*ZFILL_K=*/true>(
gmem_tiled_copy,
tAgA(_, _, _, k_tile),
tAsA(_, _, _, k_pipe),
tApA,
tAcA(_, _, _, k_tile),
max_coord_mk);

safe_copy_n<EVEN_N, EVEN_K, /*ZFILL_N=*/false, /*ZFILL_K=*/true>(
gmem_tiled_copy,
tBgB(_, _, _, k_tile),
tBsB(_, _, _, k_pipe),
tBcB(_, _, _, k_tile),
max_coord_nk);
};

// GEMM: C = A@B.T
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
@@ -270,49 +333,52 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
// s2r tiled copy for A and B
auto smem_tiled_copy_a = SmemTiledCopyA{};
auto smem_thr_copy_a = smem_tiled_copy_a.get_thread_slice(tidx);
// (BLK_M, BLK_K, PIPE) => (COPY, COPY_M, COPY_K, PIPE)
// (BLK_M, BLK_K, PIPE) => (CPY, CPY_M, CPY_K, PIPE)
auto tCsA = smem_thr_copy_a.partition_S(sA);
// (COPY, COPY_M, COPY_K)
// (CPY, CPY_M, CPY_K)
auto tCrA_cpv = smem_thr_copy_a.retile_D(tCrA);

auto smem_tiled_copy_b = SmemTiledCopyB{};
auto smem_thr_copy_b = smem_tiled_copy_b.get_thread_slice(tidx);
// (BLK_N, BLK_K, PIPE) => (COPY, COPY_N, COPY_K, PIPE)
// (BLK_N, BLK_K, PIPE) => (CPY, CPY_N, CPY_K, PIPE)
auto tCsB = smem_thr_copy_b.partition_S(sB);
// (COPY, COPY_N, COPY_K)
// (CPY, CPY_N, CPY_K)
auto tCrB_cpv = smem_thr_copy_b.retile_D(tCrB);

// ############### Prologue ###############
// remaining k-tile count
int k_tiles_remaining = size<3>(tAgA);
// next tile index in gmem to read from
int k_tile_next = 0;
int k_tile = 0;

// async loads for all pipes except the last one
auto kPipe = size<3>(tAsA);
CUTE_UNROLL
for (int k_pipe = 0; k_pipe < kPipe - 1; ++k_pipe) {
copy(gmem_tiled_copy, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, k_pipe));
copy(gmem_tiled_copy, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, k_pipe));
if (k_pipe == 0) {
produce_ab(k_tile, k_pipe);
} else {
produce_ab_no_oob(k_tile, k_pipe);
}
cp_async_fence();

// advance to next k-tile
if (--k_tiles_remaining > 0) {
++k_tile_next;
++k_tile;
}
}

// ############### Mainloop ###############
// (BLK_M, BLK_N) => (MMA, MMA_M, MMA_N)
auto tCrC = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
cute::clear(tCrC); // Clear the accumulator
auto tCrAccC = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
cute::clear(tCrAccC); // Clear the accumulator

// pipe index in smem to read from
int pipe_read = 0;
// pipe index in smem to write to
int pipe_write = kPipe - 1;

// pipe to read from: (COPY, COPY_N, COPY_K)
// pipe to read from: (CPY, CPY_N, CPY_K)
Tensor tCsA_p = tCsA(_, _, _, pipe_read);
Tensor tCsB_p = tCsB(_, _, _, pipe_read);

@@ -337,17 +403,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
// first block
if (ki == 0) {
// copy gmem to smem for next pipe
copy(gmem_tiled_copy,
tAgA(_, _, _, k_tile_next),
tAsA(_, _, _, pipe_write));
copy(gmem_tiled_copy,
tBgB(_, _, _, k_tile_next),
tBsB(_, _, _, pipe_write));
produce_ab_no_oob(k_tile, pipe_write);
cp_async_fence();

// advance to next k-tile
if (--k_tiles_remaining > 0) {
++k_tile_next;
++k_tile;
}
}
// last block
@@ -371,15 +432,17 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
copy(smem_tiled_copy_b, tCsB_p(_, _, ki_next), tCrB_cpv(_, _, ki_next));

// thread-level gemm for ki
gemm(tiled_mma, tCrA(_, _, ki), tCrB(_, _, ki), tCrC);
gemm(tiled_mma, tCrA(_, _, ki), tCrB(_, _, ki), tCrAccC);
}
}

// ############### Epilogue ###############
// (BLK_M, BLK_N)
Tensor sC = make_tensor(make_smem_ptr(ss.c_smem.data()), SmemLayoutC{});

// TODO: fastcast tCrC to DTYPE
// fastcast tCrAccC to DType
auto tCrC = make_tensor_like<DType>(tCrAccC);
fast_cast(tCrAccC, tCrC);

// copy tCrC from registers to smem
SmemTiledCopyC smem_tiled_copy_c;
@@ -396,16 +459,19 @@ __global__ __launch_bounds__(Traits::kThreadNum) void grouped_gemm_kernel_sm80(
auto gmem_thr_copy_c = gmem_tiled_copy_c.get_thread_slice(tidx);
auto tGsC = gmem_thr_copy_c.partition_S(sC);
auto tGgC = gmem_thr_copy_c.partition_D(gC);
cute::copy(gmem_tiled_copy_c, tGsC, tGgC);
// (CPY, CPY_M, CPY_N) => (M, N)
auto tGcC = gmem_thr_copy_c.partition_D(cC);
safe_copy_m<EVEN_N, /*ZFILL_M=*/false, /*ZFILL_K=*/false>(
gmem_tiled_copy_c, tGsC, tGgC, tApA, tGcC, max_coord_mn);
}

template <typename Traits, typename Params>
template <bool EVEN_N, bool EVEN_K, typename Traits, typename Params>
void launch_grouped_gemm_kernel_sm80(const Params& params,
cudaStream_t stream) {
const auto smem_size = sizeof(GEMMSharedStorageSM80<Traits>);
// std::cout << "SMEM size: " << smem_size << " bytes\n";

auto gemm_kernel = grouped_gemm_kernel_sm80<Traits, Params>;
auto gemm_kernel = grouped_gemm_kernel_sm80<EVEN_N, EVEN_K, Traits, Params>;
cudaFuncSetAttribute(
gemm_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
// TODO: support persistent kernels
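The launcher now takes EVEN_N and EVEN_K as template parameters, but the call site is outside this diff. A sketch of how the runtime-to-compile-time dispatch might look, assuming the traits type exposes kBlockN/kBlockK tile sizes (dispatch_grouped_gemm_sm80 is an illustrative name, not from this PR):

// Sketch only: pick the no-bounds-check instantiations when N and K divide
// the tile shape; otherwise fall back to the predicated copies.
template <typename Traits, typename Params>
void dispatch_grouped_gemm_sm80(const Params& params, cudaStream_t stream) {
  const bool even_n = (params.n % Traits::kBlockN) == 0;
  const bool even_k = (params.k % Traits::kBlockK) == 0;
  if (even_n && even_k) {
    launch_grouped_gemm_kernel_sm80<true, true, Traits>(params, stream);
  } else if (even_n) {
    launch_grouped_gemm_kernel_sm80<true, false, Traits>(params, stream);
  } else if (even_k) {
    launch_grouped_gemm_kernel_sm80<false, true, Traits>(params, stream);
  } else {
    launch_grouped_gemm_kernel_sm80<false, false, Traits>(params, stream);
  }
}

No EVEN_M flag is needed because the kernel pads M to kBlockM * gridDim.x and masks out-of-range rows of A and C with the tApA predicate against n_flatten_tokens.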