diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0465e26a..46a05e2e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -226,7 +226,7 @@ list(APPEND CUDA_NVCC_FLAGS -compress-all)
 
 # Enable aggresive fatbin compress for CUDA 12.8 or later.
-if(${CUDAToolkit_VERSION} VERSION_GREATER_EQUAL 12.8)
+if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
   list(APPEND CUDA_NVCC_FLAGS -compress-mode=size)
 endif()
 message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")
@@ -237,7 +237,6 @@ include(CTest)
 include(GoogleTest)
 
 # include current path
-list(APPEND COMMON_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR})
 list(APPEND COMMON_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src)
 list(APPEND COMMON_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/third_party)
diff --git a/src/kernels/gemm/CMakeLists.txt b/src/kernels/gemm/CMakeLists.txt
index 03a15eec..01fc7e5f 100644
--- a/src/kernels/gemm/CMakeLists.txt
+++ b/src/kernels/gemm/CMakeLists.txt
@@ -13,16 +13,25 @@ cc_library(
     cutlass
 )
 
+cc_test(
+  NAME
+    tile_scheduler_test
+  SRCS
+    tile_scheduler_test.cu
+  DEPS
+    :gtest_main
+    absl::random_random
+    cutlass
+)
 cc_test(
   NAME
     gemm_kernel_test
   SRCS
-    tile_scheduler_test.cu
     sm80_grouped_gemm_test.cu
   DEPS
     :gemm.kernels
-    absl::random_random
     :gtest_main
+    absl::random_random
     torch
 )
diff --git a/src/kernels/gemm/tile_scheduler.cuh b/src/kernels/gemm/tile_scheduler.cuh
index 9cadd501..834aaef4 100644
--- a/src/kernels/gemm/tile_scheduler.cuh
+++ b/src/kernels/gemm/tile_scheduler.cuh
@@ -60,9 +60,139 @@ class SingleTileScheduler {
   EndIterator end() const { return {}; }
 };
 
-enum class RasterOrder { AlongM, AlongN };
+enum class RasterOrder : int8_t { AlongM, AlongN };
 
 class StaticPersistentTileScheduler {
+ public:
+  // Host side kernel arguments
+  struct Arguments {
+    int cluster_size = 0;
+    int grid_shape_m = 0;
+    int grid_shape_n = 0;
+    int swizzle = 0;
+    RasterOrder raster_order = RasterOrder::AlongM;
+  };
+
+  static dim3 get_grid_shape(Arguments const& args, int n_sms) {
+    return {(uint32_t)n_sms};
+  }
+
+  // Device side kernel params
+  using Params = Arguments;
+  static Params to_underlying_arguments(const Arguments& args) { return args; }
+
+  class Iterator {
+   public:
+    CUTE_DEVICE
+    Iterator(int start,
+             int step,
+             const StaticPersistentTileScheduler* scheduler)
+        : linear_idx_(start), step_(step), this_(scheduler) {}
+
+    CUTE_DEVICE
+    cute::tuple<int, int> operator*() const {
+      return this_->swizzle_and_rasterize_1d(linear_idx_);
+    }
+
+    CUTE_DEVICE
+    Iterator& operator++() {
+      linear_idx_ += step_;
+      return *this;
+    }
+
+    CUTE_DEVICE
+    bool operator!=(const Iterator& e) const {
+      return linear_idx_ < e.linear_idx_;
+    }
+
+   private:
+    int linear_idx_;
+    int step_;
+    const StaticPersistentTileScheduler* this_;
+  };
+
+  // Constructor for unit tests
+  CUTE_HOST_DEVICE StaticPersistentTileScheduler(const Params& params,
+                                                 int start,
+                                                 int step)
+      : params_(params), start_(start), step_(step) {
+    major_blocks_ = params.raster_order == RasterOrder::AlongM
+                        ? params.grid_shape_m
+                        : params.grid_shape_n;
+    minor_blocks_ = params.raster_order == RasterOrder::AlongM
+                        ? params.grid_shape_n
+                        : params.grid_shape_m;
+    panel_size_ = (major_blocks_ * params.swizzle);
+  }
+
+  CUTE_DEVICE StaticPersistentTileScheduler(const Params& params)
+      : StaticPersistentTileScheduler(params,
+                                      blockIdx.x,
+                                      gridDim.x * gridDim.y * gridDim.z) {}
+
+  CUTE_DEVICE
+  Iterator begin() const { return {start_, step_, this}; }
+
+  CUTE_DEVICE
+  Iterator end() const {
+    const int problem_tiles = params_.grid_shape_m * params_.grid_shape_n;
+    return {problem_tiles, 0, this};
+  }
+
+  // compute tile coord from linear idx
+  CUTE_HOST_DEVICE cute::tuple<int, int> swizzle_and_rasterize_1d(
+      int linear_idx) const {
+    int panel_idx, panel_offset;
+    // panel_idx = linear_idx / panel_size_;
+    // panel_offset = linear_idx % panel_size_;
+    panel_size_.divmod(linear_idx, panel_idx, panel_offset);
+
+    int minor_base = panel_idx * params_.swizzle;
+    const int remaining = minor_blocks_ - minor_base;
+    // handle last partial panel
+    int swizzle = std::min(params_.swizzle, remaining);
+    // fix unaligned tma multicast
+    if (params_.cluster_size > 1 && (swizzle & 1)) {
+      const int aligned_swizzle = swizzle ^ 1;
+      const int aligned_panel_size = major_blocks_ * aligned_swizzle;
+      if (panel_offset < aligned_panel_size) {
+        // the index belongs to the aligned panel
+        swizzle = aligned_swizzle;
+      } else {
+        // the index belongs to the next tiny panel
+        panel_idx += 1;
+        swizzle = 1;
+        panel_offset -= aligned_panel_size;
+        minor_base += aligned_swizzle;
+      }
+    }
+
+    // Convert linear idx within panel into cta coord
+    int major_idx = panel_offset / swizzle;
+    const int minor_idx = (panel_offset % swizzle) + minor_base;
+    if (panel_idx & 1) {
+      // odd panel idx, reverse major index (serpentine order)
+      major_idx = (major_blocks_ - 1 - major_idx);
+    }
+
+    if (params_.raster_order == RasterOrder::AlongM) {
+      return {major_idx, minor_idx};
+    }
+    return {minor_idx, major_idx};
+  }
+
+ private:
+  Params params_;
+  int start_ = 0;
+  int step_ = 1;
+
+  // derived params for performance
+  int major_blocks_ = 0;
+  int minor_blocks_ = 0;
+  FastDivmod panel_size_;
+};
+
+class StaticPersistentTileScheduler2D {
  public:
   // Host side kernel arguments
   struct Arguments {
@@ -90,7 +220,7 @@ class StaticPersistentTileScheduler {
 
     CUTE_DEVICE
     cute::tuple<int, int> operator*() const {
-      return swizzle_and_rasterize(linear_idx_, params_);
+      return swizzle_and_rasterize_2d(linear_idx_, params_);
     }
 
     CUTE_DEVICE
@@ -110,7 +240,7 @@ class StaticPersistentTileScheduler {
     const Params& params_;
   };
 
-  CUTE_DEVICE StaticPersistentTileScheduler(const Params& params)
+  CUTE_DEVICE StaticPersistentTileScheduler2D(const Params& params)
      : params_(params) {
     linear_idx_ = blockIdx.x;
     grid_size_ = gridDim.x * gridDim.y * gridDim.z;
@@ -132,28 +262,28 @@ class StaticPersistentTileScheduler {
   // ^ +--+--+--+--+ ^
   // | |00|04|08|12| |        <---- N ---->             <- N ->
   // C +--+--+--+--+ |        <- S ->
-  // | |01|05|09|13| |        +--+--+--+--+ ^           +--+--+ ^
+  // | |01|05|09|13| |        +--+--+--+--+ |           +--+--+ |
   // v +--+--+--+--+ M   ---> |00|02|04|06| |      ---> |00|02| |
   //   |02|06|10|14| |        +--+--+--+--+ M           +--+--+ M
   //   +--+--+--+--+ |        |01|03|05|07| |           |01|03| |
   //   |03|07|11|15| |        +--+--+--+--+ v           +--+--+ v
   //   +--+--+--+--+ v
   //                 |
-  //                 v
-  // ^ +--+--+--+--+
-  // | |00|02|12|14|          <---- N ---->
-  // C +--+--+--+--+          <- S ->
-  // | |01|03|13|15|          +--+--+--+--+ ^
-  // v +--+--+--+--+     <--- |00|01|04|05| |
-  //   |04|06|08|10|          +--+--+--+--+ M
-  //   +--+--+--+--+          |02|03|06|07| |
-  //   |05|07|09|11|          +--+--+--+--+ v
-  //   +--+--+--+--+
+  //     <- S ->     v
+  // ^ +--+--+--+--+ |
+  // | |00|02|12|14| |        <---- N ---->
+  // C +--+--+--+--+ |        <- S ->
+  // | |01|03|13|15| |       +--+--+--+--+ |
+  // v +--+--+--+--+ M  <--- |00|01|04|05| |
+  //   |04|06|08|10| |       +--+--+--+--+ M
+  //   +--+--+--+--+ |       |02|03|06|07| |
+  //   |05|07|09|11| |       +--+--+--+--+ v
+  //   +--+--+--+--+ v
   //
   //                        Expand Cluster             Expand Swizzle
   // compute tile coord from linear idx
-  CUTE_HOST_DEVICE static cute::tuple<int, int> swizzle_and_rasterize(
+  CUTE_HOST_DEVICE static cute::tuple<int, int> swizzle_and_rasterize_2d(
       int linear_idx,
       const Params& params) {
     // number of ctas per cluster
diff --git a/src/kernels/gemm/tile_scheduler_test.cu b/src/kernels/gemm/tile_scheduler_test.cu
index 14b27f8b..0dba4cbc 100644
--- a/src/kernels/gemm/tile_scheduler_test.cu
+++ b/src/kernels/gemm/tile_scheduler_test.cu
@@ -14,9 +14,8 @@ class TileSchedulerTest
                                      int32_t /*swizzle*/,
                                      RasterOrder /*order*/>> {};
 
-// StaticPersistentTileScheduler
-TEST_P(TileSchedulerTest, StaticPersistent) {
-  using TileScheduler = StaticPersistentTileScheduler;
+TEST_P(TileSchedulerTest, StaticPersistent2D) {
+  using TileScheduler = StaticPersistentTileScheduler2D;
   using namespace cute;
 
   const auto [cluster_m, cluster_n, grid_m, grid_n, swizzle, order] =
@@ -31,13 +30,15 @@ TEST_P(TileSchedulerTest, StaticPersistent) {
   //     make_tensor(mapping_data.data(),
   //                 make_shape(params.grid_shape_m, params.grid_shape_n));
   int pre_tile_m = 0, pre_tile_n = 0;
-  const int max_dist = order == RasterOrder::AlongM
-                           ? (swizzle * cluster_n) + cluster_m
-                           : (swizzle * cluster_m) + cluster_n;
+  int max_dist = swizzle;
+  if (cluster_m > 1 || cluster_n > 1) {
+    max_dist = order == RasterOrder::AlongM ? (swizzle * cluster_n) + cluster_m
+                                            : (swizzle * cluster_m) + cluster_n;
+  }
   int32_t valid = 0;
   for (int linear_idx = 0; linear_idx < problem_tiles; ++linear_idx) {
     const auto [tile_m, tile_n] =
-        TileScheduler::swizzle_and_rasterize(linear_idx, params);
+        TileScheduler::swizzle_and_rasterize_2d(linear_idx, params);
 
     const int dist =
         std::abs(tile_m - pre_tile_m) + std::abs(tile_n - pre_tile_n);
@@ -56,14 +57,59 @@ TEST_P(TileSchedulerTest, StaticPersistent) {
   // print_tensor(mapping);
 }
 
+TEST_P(TileSchedulerTest, StaticPersistent1D) {
+  using TileScheduler = StaticPersistentTileScheduler;
+  using namespace cute;
+
+  const auto [cluster_m, cluster_n, grid_m, grid_n, swizzle, order] =
+      GetParam();
+
+  const int cluster_size = cluster_m * cluster_n;
+
+  // Skip test if swizzle is not a multiple of cluster size
+  if (swizzle % cluster_size != 0) {
+    return;
+  }
+
+  TileScheduler::Params params{cluster_size, grid_m, grid_n, swizzle, order};
+  TileScheduler scheduler(params, /*start=*/0, /*step=*/1);
+
+  const int problem_tiles = params.grid_shape_m * params.grid_shape_n;
+  // std::vector<int> mapping_data(problem_tiles);
+  // auto mapping =
+  //     make_tensor(mapping_data.data(),
+  //                 make_shape(params.grid_shape_m, params.grid_shape_n));
+  int pre_tile_m = 0, pre_tile_n = 0;
+  int32_t valid = 0;
+  for (int linear_idx = 0; linear_idx < problem_tiles; ++linear_idx) {
+    const auto [tile_m, tile_n] =
+        scheduler.swizzle_and_rasterize_1d(linear_idx);
+
+    const int dist =
+        std::abs(tile_m - pre_tile_m) + std::abs(tile_n - pre_tile_n);
+    pre_tile_m = tile_m;
+    pre_tile_n = tile_n;
+    EXPECT_LE(dist, swizzle);
+    // mapping(tile_m, tile_n) = linear_idx;
+
+    // (grid_m, grid_n):(1, grid_m)
+    const int idx = tile_m + (tile_n * grid_m);
+    valid ^= idx;
+    valid ^= linear_idx;
+  }
+  EXPECT_EQ(valid, 0);
+
+  // print_tensor(mapping);
+}
+
 INSTANTIATE_TEST_SUITE_P(
     SM80,
     TileSchedulerTest,
-    ::testing::Combine(::testing::Values(1, 2),     // cluster_m
-                       ::testing::Values(1, 2),     // cluster_n
cluster_n - ::testing::Values(8, 16), // grid_m - ::testing::Values(8, 16), // grid_n - ::testing::Values(1, 2, 4), // swizzle + ::testing::Combine(::testing::Values(1, 2), // cluster_m + ::testing::Values(1, 2), // cluster_n + ::testing::Values(8, 16), // grid_m + ::testing::Values(8, 16), // grid_n + ::testing::Values(1, 4, 8, 16), // swizzle ::testing::Values(RasterOrder::AlongM, RasterOrder::AlongN) // order )); diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 5595d8c7..a5b75a92 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -1,23 +1,15 @@ include(cc_library) cc_library( - NAME + NAME cutlass INCLUDES cutlass/include - DEPS + DEPS torch # TODO: depends on CUDA instead of torch ) -cc_library( - NAME - flashinfer - INCLUDES - flashinfer/include -) - add_subdirectory(sentencepiece) if (BUILD_NVBENCH) add_subdirectory(nvbench) endif() -