diff --git a/src/kernels/gemm/CMakeLists.txt b/src/kernels/gemm/CMakeLists.txt
index 97f28a80..a5d735c7 100644
--- a/src/kernels/gemm/CMakeLists.txt
+++ b/src/kernels/gemm/CMakeLists.txt
@@ -18,6 +18,7 @@ cc_test(
   NAME
     gemm_kernel_test
   SRCS
+    tile_scheduler_test.cu
     sm80_grouped_gemm_test.cu
   DEPS
     :gemm.kernels
diff --git a/src/kernels/gemm/fast_math.h b/src/kernels/gemm/fast_math.h
new file mode 100644
index 00000000..6dfeb755
--- /dev/null
+++ b/src/kernels/gemm/fast_math.h
@@ -0,0 +1,73 @@
+#pragma once
+
+#include <cstdint>
+
+#include <cute/config.hpp>
+
+namespace llm {
+
+// Fast integer division/modulo by a runtime constant, using the classic
+// round-up magic-number trick: src / div == (src * mul_) >> (32 + shr_).
+struct FastDivmod {
+  int32_t div_ = 1;
+  uint32_t mul_ = 0u;
+  uint32_t shr_ = 0u;
+
+  CUTE_HOST_DEVICE
+  void reset(int div) {
+    div_ = div;
+    if (div_ != 1) {
+      // p = 31 + ceil(log2(div)); derive it from the bit width of (div - 1)
+      // so non-powers of two round up correctly
+      unsigned int p = 31;
+      for (uint32_t d = uint32_t(div_) - 1; d > 0; d >>= 1) {
+        ++p;
+      }
+      unsigned m =
+          unsigned(((1ull << p) + unsigned(div_) - 1) / unsigned(div_));
+
+      mul_ = m;
+      shr_ = p - 32;
+    }
+  }
+
+  constexpr FastDivmod() = default;
+
+  CUTE_HOST_DEVICE
+  FastDivmod(int div) { reset(div); }
+
+  CUTE_HOST_DEVICE
+  FastDivmod& operator=(int div) {
+    reset(div);
+    return *this;
+  }
+
+  CUTE_HOST_DEVICE
+  void divmod(int src, int& quo, int& rem) const {
+    quo = div(src);
+    rem = src - (quo * div_);
+  }
+
+  CUTE_HOST_DEVICE
+  int div(int src) const {
+#if defined(__CUDA_ARCH__)
+    return (div_ != 1) ? __umulhi(src, mul_) >> shr_ : src;
+#else
+    return src / div_;
+#endif
+  }
+
+  CUTE_HOST_DEVICE
+  int mod(int src) const {
+#if defined(__CUDA_ARCH__)
+    return (div_ != 1) ? src - (div(src) * div_) : 0;
+#else
+    return src % div_;
+#endif
+  }
+
+  CUTE_HOST_DEVICE
+  operator int() const { return div_; }
+};
+
+// operator overloads for FastDivmod
+CUTE_HOST_DEVICE int operator/(int src, const FastDivmod& d) {
+  return d.div(src);
+}
+CUTE_HOST_DEVICE int operator%(int src, const FastDivmod& d) {
+  return d.mod(src);
+}
+
+}  // namespace llm
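For reference, FastDivmod implements the classic round-up magic-number division (the same scheme as CUTLASS's FastDivmod): for div > 1 it precomputes p = 31 + ceil(log2(div)) and mul = ceil(2^p / div), so that src / div == (src * mul) >> p for any non-negative int src. The device path evaluates this as __umulhi(src, mul_) >> shr_ with shr_ = p - 32. Below is a minimal host-side harness (a sketch, not part of the patch; assumes CuTe is on the include path) that emulates the device formula and cross-checks both paths against plain division:

#include <cassert>
#include <cstdint>

#include "fast_math.h"

int main() {
  for (int div : {1, 2, 3, 5, 7, 16, 48}) {
    llm::FastDivmod fd(div);
    for (int src = 0; src < (1 << 20); ++src) {
      // emulate the device path: __umulhi(src, mul_) is the high 32 bits
      // of the 64-bit product
      const uint32_t hi = uint32_t((uint64_t(uint32_t(src)) * fd.mul_) >> 32);
      const int dev_quo = (div == 1) ? src : int(hi >> fd.shr_);
      assert(dev_quo == src / div);

      // host fallback and the operator overloads
      int quo, rem;
      fd.divmod(src, quo, rem);
      assert(quo == src / div && rem == src % div);
      assert(src / fd == quo && src % fd == rem);
    }
  }
  return 0;
}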
diff --git a/src/kernels/gemm/tile_scheduler.cuh b/src/kernels/gemm/tile_scheduler.cuh
index 5388e5aa..9cadd501 100644
--- a/src/kernels/gemm/tile_scheduler.cuh
+++ b/src/kernels/gemm/tile_scheduler.cuh
@@ -3,9 +3,12 @@
 #include <cuda.h>
 #include <cute/config.hpp>
+#include <cute/container/tuple.hpp>
 #include <cute/layout.hpp>
 #include <cute/tensor.hpp>
 
+#include "fast_math.h"
+
 namespace llm {
 
 class SingleTileScheduler {
@@ -57,4 +60,161 @@ class SingleTileScheduler {
   EndIterator end() const { return {}; }
 };
 
+enum class RasterOrder { AlongM, AlongN };
+
+class StaticPersistentTileScheduler {
+ public:
+  // Host side kernel arguments
+  struct Arguments {
+    FastDivmod cluster_shape_m = 0;
+    FastDivmod cluster_shape_n = 0;
+    int grid_shape_m = 0;
+    int grid_shape_n = 0;
+    FastDivmod swizzle = 0;
+    RasterOrder raster_order = RasterOrder::AlongM;
+  };
+
+  // launch one persistent CTA per SM
+  static dim3 get_grid_shape(Arguments const& args, int n_sms) {
+    return {(uint32_t)n_sms};
+  }
+
+  // Device side kernel params
+  using Params = Arguments;
+  static Params to_underlying_arguments(const Arguments& args) { return args; }
+
+  class Iterator {
+   public:
+    CUTE_DEVICE
+    Iterator(int start, int step, const Params& params)
+        : linear_idx_(start), step_(step), params_(params) {}
+
+    CUTE_DEVICE
+    cute::tuple<int, int> operator*() const {
+      return swizzle_and_rasterize(linear_idx_, params_);
+    }
+
+    CUTE_DEVICE
+    Iterator& operator++() {
+      linear_idx_ += step_;
+      return *this;
+    }
+
+    CUTE_DEVICE
+    bool operator!=(const Iterator& e) const {
+      return linear_idx_ < e.linear_idx_;
+    }
+
+   private:
+    int linear_idx_;
+    int step_;
+    const Params& params_;
+  };
+
+  CUTE_DEVICE StaticPersistentTileScheduler(const Params& params)
+      : params_(params) {
+    linear_idx_ = blockIdx.x;
+    grid_size_ = gridDim.x * gridDim.y * gridDim.z;
+  }
+
+  CUTE_DEVICE
+  Iterator begin() const { return {linear_idx_, grid_size_, params_}; }
+
+  CUTE_DEVICE
+  Iterator end() const {
+    const int problem_tiles = params_.grid_shape_m * params_.grid_shape_n;
+    return {problem_tiles, 0, params_};
+  }
+
+  // For example: ClusterShape (2, 1), GridShape (4, 4), Swizzle = 2, along M
+  //
+  //              Reduce Cluster                Reduce Swizzle
+  //    <---- N ---->
+  //  ^ +--+--+--+--+ ^
+  //  | |00|04|08|12| |           <---- N ---->          <- N ->
+  //  C +--+--+--+--+ |                                  <- S ->
+  //  | |01|05|09|13| |          +--+--+--+--+  ^       +--+--+  ^
+  //  v +--+--+--+--+ M   --->   |00|02|04|06|  |  ---> |00|02|  |
+  //    |02|06|10|14| |          +--+--+--+--+  M       +--+--+  M
+  //    +--+--+--+--+ |          |01|03|05|07|  |       |01|03|  |
+  //    |03|07|11|15| |          +--+--+--+--+  v       +--+--+  v
+  //    +--+--+--+--+ v
+  //                                                        |
+  //                                                        v
+  //  ^ +--+--+--+--+
+  //  | |00|02|12|14|             <---- N ---->
+  //  C +--+--+--+--+              <- S ->
+  //  | |01|03|13|15|            +--+--+--+--+  ^
+  //  v +--+--+--+--+   <---     |00|01|04|05|  |
+  //    |04|06|08|10|            +--+--+--+--+  M
+  //    +--+--+--+--+            |02|03|06|07|  |
+  //    |05|07|09|11|            +--+--+--+--+  v
+  //    +--+--+--+--+
+  //
+  //              Expand Cluster                Expand Swizzle
+
+  // compute the swizzled tile coord from a linear idx
+  CUTE_HOST_DEVICE static cute::tuple<int, int> swizzle_and_rasterize(
+      int linear_idx,
+      const Params& params) {
+    // number of ctas per cluster
+    const int cluster_size = params.cluster_shape_m * params.cluster_shape_n;
+    // number of clusters along the major dimension
+    const int major_clusters =
+        params.raster_order == RasterOrder::AlongM
+            ? params.grid_shape_m / params.cluster_shape_m
+            : params.grid_shape_n / params.cluster_shape_n;
+
+    // Convert the linear CTA idx into a cluster idx and an offset within
+    // the cluster. Layout: (cluster_size, clusters):(1, cluster_size)
+    const int cluster_idx = linear_idx / cluster_size;
+    const int cluster_offset = linear_idx % cluster_size;
+
+    // Convert the linear idx within the cluster into a cta coord.
+    // Layout: (cluster_shape_m, cluster_shape_n):(1, cluster_shape_m)
+    int cluster_offset_m, cluster_offset_n;
+    params.cluster_shape_m.divmod(
+        cluster_offset, cluster_offset_n, cluster_offset_m);
+
+    int major_idx, minor_idx, panel_idx;
+    if (params.swizzle > 1) {
+      // Convert the cluster linear idx into a swizzled coord.
+      // Layout: (swizzle, panels):(1, swizzle)
+      int swizzle_idx, swizzle_offset;
+      params.swizzle.divmod(cluster_idx, swizzle_idx, swizzle_offset);
+
+      major_idx = swizzle_idx % major_clusters;
+      panel_idx = swizzle_idx / major_clusters;
+      minor_idx = panel_idx * params.swizzle + swizzle_offset;
+    } else {
+      // no swizzle, panel size = 1
+      major_idx = cluster_idx % major_clusters;
+      panel_idx = cluster_idx / major_clusters;
+      minor_idx = panel_idx;
+    }
+
+    if ((panel_idx & 1) != 0) {
+      // odd panel: reverse the major index for serpentine traversal
+      major_idx = (major_clusters - 1 - major_idx);
+    }
+
+    if (params.raster_order == RasterOrder::AlongM) {
+      // Map the swizzled cluster tile back to a CTA tile
+      major_idx = major_idx * params.cluster_shape_m + cluster_offset_m;
+      minor_idx = minor_idx * params.cluster_shape_n + cluster_offset_n;
+      return {major_idx, minor_idx};
+    }
+
+    // raster_order == AlongN
+    // Map the swizzled cluster tile back to a CTA tile
+    minor_idx = minor_idx * params.cluster_shape_m + cluster_offset_m;
+    major_idx = major_idx * params.cluster_shape_n + cluster_offset_n;
+    return {minor_idx, major_idx};
+  }
+
+ private:
+  int linear_idx_ = 0;
+  int grid_size_ = 0;
+  Params params_;
+};
+
 }  // namespace llm
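The scheduler is built for a persistent kernel: get_grid_shape launches one CTA per SM, and each CTA starts at blockIdx.x and strides through the swizzled tile space by the grid size until the problem tiles are exhausted. Since begin()/end() are exposed, a device-side range-for works directly. A minimal skeleton (hypothetical kernel, not part of the patch):

#include "tile_scheduler.cuh"

namespace llm {

// hypothetical persistent-kernel skeleton driven by the scheduler
template <typename TileScheduler>
__global__ void persistent_gemm_kernel(
    const typename TileScheduler::Params params) {
  TileScheduler scheduler(params);
  // each CTA visits every grid_size-th tile in swizzled order
  for (const auto [tile_m, tile_n] : scheduler) {
    // ... load/compute/store the (tile_m, tile_n) output tile ...
  }
}

}  // namespace llm

On the host, the launch follows the Arguments/Params split, with n_sms and the CTA shape supplied by the caller:

using TileScheduler = llm::StaticPersistentTileScheduler;
TileScheduler::Arguments args{/* cluster shape, grid shape, swizzle, order */};
const dim3 grid = TileScheduler::get_grid_shape(args, n_sms);
const dim3 block(128);  // CTA shape, an assumption for illustration
persistent_gemm_kernel<TileScheduler>
    <<<grid, block>>>(TileScheduler::to_underlying_arguments(args));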
diff --git a/src/kernels/gemm/tile_scheduler_test.cu b/src/kernels/gemm/tile_scheduler_test.cu
new file mode 100644
index 00000000..8f722321
--- /dev/null
+++ b/src/kernels/gemm/tile_scheduler_test.cu
@@ -0,0 +1,71 @@
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <cstdlib>
+#include <tuple>
+
+#include "tile_scheduler.cuh"
+
+namespace llm {
+
+class TileSchedulerTest
+    : public ::testing::TestWithParam<
+          std::tuple<int, int, int, int, int, RasterOrder>> {};
+
+// StaticPersistentTileScheduler
+TEST_P(TileSchedulerTest, StaticPersistent) {
+  using TileScheduler = StaticPersistentTileScheduler;
+  using namespace cute;
+
+  const auto [cluster_m, cluster_n, grid_m, grid_n, swizzle, order] =
+      GetParam();
+
+  TileScheduler::Params params{
+      cluster_m, cluster_n, grid_m, grid_n, swizzle, order};
+
+  const int problem_tiles = params.grid_shape_m * params.grid_shape_n;
+  // std::vector<int> mapping_data(problem_tiles);
+  // auto mapping =
+  //     make_tensor(mapping_data.data(),
+  //                 make_shape(params.grid_shape_m, params.grid_shape_n));
+  int pre_tile_m = 0, pre_tile_n = 0;
+  // consecutive tiles should stay within one swizzle panel of each other
+  const int max_dist = order == RasterOrder::AlongM
+                           ? (swizzle * cluster_n) + cluster_m
+                           : (swizzle * cluster_m) + cluster_n;
+  int32_t valid = 0;
+  for (int linear_idx = 0; linear_idx < problem_tiles; ++linear_idx) {
+    const auto [tile_m, tile_n] =
+        TileScheduler::swizzle_and_rasterize(linear_idx, params);
+
+    const int dist =
+        std::abs(tile_m - pre_tile_m) + std::abs(tile_n - pre_tile_n);
+    pre_tile_m = tile_m;
+    pre_tile_n = tile_n;
+    EXPECT_LE(dist, max_dist);
+    // mapping(tile_m, tile_n) = linear_idx;
+
+    // cheap consistency check: the xor of visited tile idxs must cancel
+    // the xor of linear idxs. Layout: (grid_m, grid_n):(1, grid_m)
+    const int idx = tile_m + (tile_n * grid_m);
+    valid ^= idx;
+    valid ^= linear_idx;
+  }
+  EXPECT_EQ(valid, 0);
+
+  // print_tensor(mapping);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    TileScheduler,
+    TileSchedulerTest,
+    ::testing::Combine(::testing::Values(1, 2),    // cluster_m
+                       ::testing::Values(1, 2),    // cluster_n
+                       ::testing::Values(8, 16),   // grid_m
+                       ::testing::Values(8, 16),   // grid_n
+                       ::testing::Values(1, 2, 4),  // swizzle
+                       ::testing::Values(RasterOrder::AlongM,
+                                         RasterOrder::AlongN)  // order
+                       ));
+
+}  // namespace llm
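Since swizzle_and_rasterize is a CUTE_HOST_DEVICE static, it can also be driven from plain host code to visualize a mapping. This sketch (not part of the patch; built with nvcc as a .cu file, since the header's device-side members need CUDA) replays the worked example from the header comment, ClusterShape (2, 1), GridShape (4, 4), Swizzle = 2 along M, and prints the bottom-left grid of the ASCII diagram (rows 00 02 12 14 / 01 03 13 15 / 04 06 08 10 / 05 07 09 11):

#include <cstdio>

#include "tile_scheduler.cuh"

int main() {
  using TileScheduler = llm::StaticPersistentTileScheduler;
  // ClusterShape (2, 1), GridShape (4, 4), Swizzle = 2, rasterized along M
  TileScheduler::Params params{2, 1, 4, 4, 2, llm::RasterOrder::AlongM};

  // invert the schedule: mapping[m][n] holds the visit order of tile (m, n)
  int mapping[4][4];
  for (int linear_idx = 0; linear_idx < 16; ++linear_idx) {
    const auto [tile_m, tile_n] =
        TileScheduler::swizzle_and_rasterize(linear_idx, params);
    mapping[tile_m][tile_n] = linear_idx;
  }
  for (int m = 0; m < 4; ++m) {
    for (int n = 0; n < 4; ++n) {
      std::printf("%02d ", mapping[m][n]);
    }
    std::printf("\n");
  }
  return 0;
}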