Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ list(APPEND CUDA_NVCC_FLAGS
-compress-all)

# Enable aggressive fatbin compression for CUDA 12.8 or later.
if(${CUDAToolkit_VERSION} VERSION_GREATER_EQUAL 12.8)
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
list(APPEND CUDA_NVCC_FLAGS -compress-mode=size)
endif()
message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")
Expand All @@ -237,7 +237,6 @@ include(CTest)
include(GoogleTest)

# include current path
list(APPEND COMMON_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR})
list(APPEND COMMON_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src)
list(APPEND COMMON_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/third_party)

Expand Down
13 changes: 11 additions & 2 deletions src/kernels/gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,25 @@ cc_library(
cutlass
)

# Unit test target built from tile_scheduler_test.cu only; it links gtest_main,
# absl::random_random and cutlass.
cc_test(
NAME
tile_scheduler_test
SRCS
tile_scheduler_test.cu
DEPS
:gtest_main
absl::random_random
cutlass
)

cc_test(
NAME
gemm_kernel_test
SRCS
tile_scheduler_test.cu
sm80_grouped_gemm_test.cu
DEPS
:gemm.kernels
absl::random_random
:gtest_main
absl::random_random
torch
)
160 changes: 145 additions & 15 deletions src/kernels/gemm/tile_scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,139 @@ class SingleTileScheduler {
EndIterator end() const { return {}; }
};

enum class RasterOrder { AlongM, AlongN };
enum class RasterOrder : int8_t { AlongM, AlongN };

// Tile scheduler that walks a 1-D linear tile index and maps every index to a
// (tile_m, tile_n) coordinate. Tiles are grouped into panels that are
// `swizzle` tiles wide along the minor dimension and span the whole major
// dimension; odd panels are traversed in the opposite major direction.
class StaticPersistentTileScheduler {
public:
// Host side kernel arguments
struct Arguments {
// Values greater than 1 enable the even-width panel fix-up in
// swizzle_and_rasterize_1d().
int cluster_size = 0;
// Number of tiles along M.
int grid_shape_m = 0;
// Number of tiles along N.
int grid_shape_n = 0;
// Panel width, in tiles, along the minor dimension.
// NOTE(review): the default of 0 makes panel_size_ a zero divisor;
// presumably callers always pass swizzle >= 1 -- confirm.
int swizzle = 0;
// Selects which of M / N is the major dimension.
RasterOrder raster_order = RasterOrder::AlongM;
};

// Returns a 1-D launch grid whose x extent is `n_sms`. `args` is not used.
static dim3 get_grid_shape(Arguments const& args, int n_sms) {
return {(uint32_t)n_sms};
}

// Device side kernel params
using Params = Arguments;
// Params is an alias of Arguments, so the arguments are returned unchanged.
static Params to_underlying_arguments(const Arguments& args) { return args; }

// Forward iterator over linear tile indices; dereferencing yields the
// (tile_m, tile_n) coordinate computed by the owning scheduler.
class Iterator {
public:
// start:     first linear tile index to visit
// step:      increment applied by operator++
// scheduler: scheduler used to translate the index; stored as a raw pointer,
//            so it must outlive the iterator
CUTE_DEVICE
Iterator(int start,
int step,
const StaticPersistentTileScheduler* scheduler)
: linear_idx_(start), step_(step), this_(scheduler) {}

// Maps the current linear index to a (tile_m, tile_n) coordinate.
CUTE_DEVICE
cute::tuple<int, int> operator*() const {
return this_->swizzle_and_rasterize_1d(linear_idx_);
}

// Advances the linear index by the configured step.
CUTE_DEVICE
Iterator& operator++() {
linear_idx_ += step_;
return *this;
}

// Implemented as a less-than comparison, so iteration stops once the index
// reaches or passes the end index even if the step never lands exactly on it.
CUTE_DEVICE
bool operator!=(const Iterator& e) const {
return linear_idx_ < e.linear_idx_;
}

private:
int linear_idx_;
int step_;
const StaticPersistentTileScheduler* this_;
};

// Constructor for unit tests
// Takes the start index and step explicitly instead of reading them from
// blockIdx / gridDim, and precomputes the major/minor extents and panel size.
CUTE_HOST_DEVICE StaticPersistentTileScheduler(const Params& params,
int start,
int step)
: params_(params), start_(start), step_(step) {
// AlongM: M is the major dimension and N the minor one; AlongN: the reverse.
major_blocks_ = params.raster_order == RasterOrder::AlongM
? params.grid_shape_m
: params.grid_shape_n;
minor_blocks_ = params.raster_order == RasterOrder::AlongM
? params.grid_shape_n
: params.grid_shape_m;
// Number of tiles in one full panel.
panel_size_ = (major_blocks_ * params.swizzle);
}

// Device constructor: starts at blockIdx.x and strides by the total number of
// blocks in the launch grid.
CUTE_DEVICE StaticPersistentTileScheduler(const Params& params)
: StaticPersistentTileScheduler(params,
blockIdx.x,
gridDim.x * gridDim.y * gridDim.z) {}

// Iterator positioned at this scheduler's start index.
CUTE_DEVICE
Iterator begin() const { return {start_, step_, this}; }

// Sentinel iterator positioned at the total tile count
// (grid_shape_m * grid_shape_n); its step is 0 because it is never advanced.
CUTE_DEVICE
Iterator end() const {
const int problem_tiles = params_.grid_shape_m * params_.grid_shape_n;
return {problem_tiles, 0, this};
}

// compute tile coord from linear idx
// The returned tuple is always ordered (tile_m, tile_n), regardless of the
// raster order.
CUTE_HOST_DEVICE cute::tuple<int, int> swizzle_and_rasterize_1d(
int linear_idx) const {
int panel_idx, panel_offset;
// panel_idx = linear_idx / panel_size_;
// panel_offset = linear_idx % panel_size_;
panel_size_.divmod(linear_idx, panel_idx, panel_offset);

// first minor index covered by this panel
int minor_base = panel_idx * params_.swizzle;
const int remaining = minor_blocks_ - minor_base;
// handle last partial panel
int swizzle = std::min<int>(params_.swizzle, remaining);
// fix unaligned tma multicast
// With clusters enabled and an odd effective width, split the panel into an
// even-width panel followed by a trailing panel of width 1.
if (params_.cluster_size > 1 && (swizzle & 1)) {
// swizzle is odd here, so clearing the low bit gives swizzle - 1
const int aligned_swizzle = swizzle ^ 1;
const int aligned_panel_size = major_blocks_ * aligned_swizzle;
if (panel_offset < aligned_panel_size) {
// the index belongs to the aligned panel
swizzle = aligned_swizzle;
} else {
// the index belongs to next tiny panel
panel_idx += 1;
swizzle = 1;
panel_offset -= aligned_panel_size;
minor_base += aligned_swizzle;
}
}

// Convert linear idx within panel into cta coord
int major_idx = panel_offset / swizzle;
const int minor_idx = (panel_offset % swizzle) + minor_base;
if (panel_idx & 1) {
// odd panel index: reverse the major index so that this panel is traversed
// in the opposite direction to its even-indexed neighbours
major_idx = (major_blocks_ - 1 - major_idx);
}

// AlongM: major is M, so (major, minor) is already (m, n); otherwise swap.
if (params_.raster_order == RasterOrder::AlongM) {
return {major_idx, minor_idx};
}
return {minor_idx, major_idx};
}

private:
Params params_;
// first linear tile index handled by this scheduler instance
int start_ = 0;
// distance between consecutive linear tile indices handled by this instance
int step_ = 1;

// derived params for performance
// extent of the major / minor dimension, in tiles
int major_blocks_ = 0;
int minor_blocks_ = 0;
// major_blocks_ * swizzle, kept as a FastDivmod for the divmod in
// swizzle_and_rasterize_1d()
FastDivmod panel_size_;
};

class StaticPersistentTileScheduler2D {
public:
// Host side kernel arguments
struct Arguments {
Expand Down Expand Up @@ -90,7 +220,7 @@ class StaticPersistentTileScheduler {

CUTE_DEVICE
cute::tuple<int, int> operator*() const {
return swizzle_and_rasterize(linear_idx_, params_);
return swizzle_and_rasterize_2d(linear_idx_, params_);
}

CUTE_DEVICE
Expand All @@ -110,7 +240,7 @@ class StaticPersistentTileScheduler {
const Params& params_;
};

CUTE_DEVICE StaticPersistentTileScheduler(const Params& params)
CUTE_DEVICE StaticPersistentTileScheduler2D(const Params& params)
: params_(params) {
linear_idx_ = blockIdx.x;
grid_size_ = gridDim.x * gridDim.y * gridDim.z;
Expand All @@ -132,28 +262,28 @@ class StaticPersistentTileScheduler {
// ^ +--+--+--+--+ ^
// | |00|04|08|12| | <---- N ----> <- N ->
// C +--+--+--+--+ | <- S ->
// | |01|05|09|13| | +--+--+--+--+ ^ +--+--+ ^
// | |01|05|09|13| | +--+--+--+--+ | +--+--+ |
// v +--+--+--+--+ M ---> |00|02|04|06| | ---> |00|02| |
// |02|06|10|14| | +--+--+--+--+ M +--+--+ M
// +--+--+--+--+ | |01|03|05|07| | |01|03| |
// |03|07|11|15| | +--+--+--+--+ v +--+--+ v
// +--+--+--+--+ v
// |
// v
// ^ +--+--+--+--+
// | |00|02|12|14| <---- N ---->
// C +--+--+--+--+ <- S ->
// | |01|03|13|15| +--+--+--+--+ ^
// v +--+--+--+--+ <--- |00|01|04|05| |
// |04|06|08|10| +--+--+--+--+ M
// +--+--+--+--+ |02|03|06|07| |
// |05|07|09|11| +--+--+--+--+ v
// +--+--+--+--+
// <- S -> v
// ^ +--+--+--+--+ |
// | |00|02|12|14| | <---- N ---->
// C +--+--+--+--+ | <- S ->
// | |01|03|13|15| | +--+--+--+--+ |
// v +--+--+--+--+ M <--- |00|01|04|05| |
// |04|06|08|10| | +--+--+--+--+ M
// +--+--+--+--+ | |02|03|06|07| |
// |05|07|09|11| | +--+--+--+--+ v
// +--+--+--+--+ v
//
// Expand Cluster Expand Swizzle

// compute tile coord from linear idx
CUTE_HOST_DEVICE static cute::tuple<int, int> swizzle_and_rasterize(
CUTE_HOST_DEVICE static cute::tuple<int, int> swizzle_and_rasterize_2d(
int linear_idx,
const Params& params) {
// number of ctas per cluster
Expand Down
70 changes: 58 additions & 12 deletions src/kernels/gemm/tile_scheduler_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ class TileSchedulerTest
int32_t /*swizzle*/,
RasterOrder /*order*/>> {};

// StaticPersistentTileScheduler
TEST_P(TileSchedulerTest, StaticPersistent) {
using TileScheduler = StaticPersistentTileScheduler;
TEST_P(TileSchedulerTest, StaticPersistent2D) {
using TileScheduler = StaticPersistentTileScheduler2D;
using namespace cute;

const auto [cluster_m, cluster_n, grid_m, grid_n, swizzle, order] =
Expand All @@ -31,13 +30,15 @@ TEST_P(TileSchedulerTest, StaticPersistent) {
// make_tensor(mapping_data.data(),
// make_shape(params.grid_shape_m, params.grid_shape_n));
int pre_tile_m = 0, pre_tile_n = 0;
const int max_dist = order == RasterOrder::AlongM
? (swizzle * cluster_n) + cluster_m
: (swizzle * cluster_m) + cluster_n;
int max_dist = swizzle;
if (cluster_m > 1 || cluster_n > 1) {
max_dist = order == RasterOrder::AlongM ? (swizzle * cluster_n) + cluster_m
: (swizzle * cluster_m) + cluster_n;
}
int32_t valid = 0;
for (int linear_idx = 0; linear_idx < problem_tiles; ++linear_idx) {
const auto [tile_m, tile_n] =
TileScheduler::swizzle_and_rasterize(linear_idx, params);
TileScheduler::swizzle_and_rasterize_2d(linear_idx, params);

const int dist =
std::abs(tile_m - pre_tile_m) + std::abs(tile_n - pre_tile_n);
Expand All @@ -56,14 +57,59 @@ TEST_P(TileSchedulerTest, StaticPersistent) {
// print_tensor(mapping);
}

// Exercises StaticPersistentTileScheduler::swizzle_and_rasterize_1d over every
// linear tile index of the parameterized grid and checks two properties:
//   * consecutive tiles are at most `swizzle` apart (Manhattan distance);
//   * the XOR of all produced column-major tile indices cancels the XOR of all
//     linear indices (a necessary condition for the mapping to be a
//     permutation of the grid).
TEST_P(TileSchedulerTest, StaticPersistent1D) {
  using TileScheduler = StaticPersistentTileScheduler;
  using namespace cute;

  const auto [cluster_m, cluster_n, grid_m, grid_n, swizzle, order] =
      GetParam();

  const int cluster_size = cluster_m * cluster_n;

  // Parameter combinations whose swizzle is not a multiple of the cluster
  // size are not exercised by this test.
  if (swizzle % cluster_size != 0) {
    return;
  }

  const TileScheduler::Params sched_params{
      cluster_size, grid_m, grid_n, swizzle, order};
  const TileScheduler sched(sched_params, /*start=*/0, /*step=*/1);

  const int num_tiles = sched_params.grid_shape_m * sched_params.grid_shape_n;

  // Debug aid: record each linear index in a (grid_shape_m, grid_shape_n)
  // tensor inside the loop and print_tensor() it afterwards to inspect the
  // resulting traversal order.

  // The distance of the very first tile is measured from the origin.
  int last_m = 0;
  int last_n = 0;
  int32_t valid = 0;
  for (int tile_id = 0; tile_id < num_tiles; ++tile_id) {
    const auto [tile_m, tile_n] = sched.swizzle_and_rasterize_1d(tile_id);

    const int step_m = std::abs(tile_m - last_m);
    const int step_n = std::abs(tile_n - last_n);
    const int dist = step_m + step_n;
    last_m = tile_m;
    last_n = tile_n;
    EXPECT_LE(dist, swizzle);

    // Column-major flattening, layout (grid_m, grid_n):(1, grid_m).
    const int flat_idx = tile_m + (tile_n * grid_m);
    valid ^= (flat_idx ^ tile_id);
  }
  EXPECT_EQ(valid, 0);
}

INSTANTIATE_TEST_SUITE_P(
SM80,
TileSchedulerTest,
::testing::Combine(::testing::Values(1, 2), // cluster_m
::testing::Values(1, 2), // cluster_n
::testing::Values(8, 16), // grid_m
::testing::Values(8, 16), // grid_n
::testing::Values(1, 2, 4), // swizzle
::testing::Combine(::testing::Values(1, 2), // cluster_m
::testing::Values(1, 2), // cluster_n
::testing::Values(8, 16), // grid_m
::testing::Values(8, 16), // grid_n
::testing::Values(1, 4, 8, 16), // swizzle
::testing::Values(RasterOrder::AlongM,
RasterOrder::AlongN) // order
));
Expand Down
12 changes: 2 additions & 10 deletions third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
include(cc_library)

cc_library(
NAME
NAME
cutlass
INCLUDES
cutlass/include
DEPS
DEPS
torch # TODO: depends on CUDA instead of torch
)

cc_library(
NAME
flashinfer
INCLUDES
flashinfer/include
)

add_subdirectory(sentencepiece)
if (BUILD_NVBENCH)
add_subdirectory(nvbench)
endif()