1 change: 1 addition & 0 deletions src/kernels/gemm/CMakeLists.txt
@@ -18,6 +18,7 @@ cc_test(
NAME
gemm_kernel_test
SRCS
tile_scheduler_test.cu
sm80_grouped_gemm_test.cu
DEPS
:gemm.kernels
73 changes: 73 additions & 0 deletions src/kernels/gemm/fast_math.h
@@ -0,0 +1,73 @@
#pragma once

#include <cuda.h>

#include <cute/config.hpp>

namespace llm {
// Fast integer division and modulo by a runtime divisor, using a precomputed
// multiply-and-shift "magic number" so that device code avoids hardware
// integer division.
struct FastDivmod {
int32_t div_ = 1;
uint32_t mul_ = 0u;
uint32_t shr_ = 0u;

CUTE_HOST_DEVICE
  void reset(int div) {
    div_ = div;
    if (div_ != 1) {
      // Multiply-and-shift ("magic number") division:
      //   p = 31 + ceil(log2(div)), m = ceil(2^p / div)
      // so that floor(n / div) == __umulhi(n, m) >> (p - 32) for 32-bit n.
      // ceil(log2) is required; rounding the exponent down breaks
      // non-power-of-two divisors for large dividends.
      unsigned int k = 1;  // k = ceil(log2(div_)); div_ >= 2 here
      while ((1u << k) < unsigned(div_)) {
        ++k;
      }
      const unsigned int p = 31 + k;
      const unsigned m =
          unsigned(((1ull << p) + unsigned(div_) - 1) / unsigned(div_));

      mul_ = m;
      shr_ = p - 32;
    }
  }

constexpr FastDivmod() = default;

CUTE_HOST_DEVICE
FastDivmod(int div) { reset(div); }

CUTE_HOST_DEVICE
FastDivmod& operator=(int div) {
reset(div);
return *this;
}

CUTE_HOST_DEVICE
void divmod(int src, int& quo, int& rem) const {
quo = div(src);
rem = src - (quo * div_);
}

CUTE_HOST_DEVICE
int div(int src) const {
#if defined(__CUDA_ARCH__)
return (div_ != 1) ? __umulhi(src, mul_) >> shr_ : src;
#else
return src / div_;
#endif
}

CUTE_HOST_DEVICE
int mod(int src) const {
#if defined(__CUDA_ARCH__)
return div_ != 1 ? src - (div(src) * div_) : 0;
#else
return src % div_;
#endif
}

CUTE_HOST_DEVICE
operator int() const { return div_; }
};

// operator overloads for FastDivmod
CUTE_HOST_DEVICE int operator/(int src, const FastDivmod& d) {
return d.div(src);
}
CUTE_HOST_DEVICE int operator%(int src, const FastDivmod& d) {
return d.mod(src);
}

} // namespace llm
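For context, a minimal host-side sketch of how FastDivmod is meant to be used (illustrative only; the function name and the divisor value below are arbitrary, not part of the header):

#include <cassert>

#include "fast_math.h"

void fast_divmod_example() {
  llm::FastDivmod d(3);      // preprocess the divisor once
  int quo = 0, rem = 0;
  d.divmod(14, quo, rem);    // quo == 4, rem == 2
  assert(quo == 14 / 3 && rem == 14 % 3);
  assert(100 / d == 33);     // operator/ overload above
  assert(100 % d == 1);      // operator% overload above
}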
160 changes: 160 additions & 0 deletions src/kernels/gemm/tile_scheduler.cuh
@@ -3,9 +3,12 @@
#include <cuda.h>
#include <cuda_runtime.h>

#include <cute/config.hpp>
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "fast_math.h"

namespace llm {

class SingleTileScheduler {
@@ -57,4 +60,161 @@ class SingleTileScheduler {
EndIterator end() const { return {}; }
};

enum class RasterOrder { AlongM, AlongN };

class StaticPersistentTileScheduler {
public:
// Host side kernel arguments
struct Arguments {
FastDivmod cluster_shape_m = 0;
FastDivmod cluster_shape_n = 0;
int grid_shape_m = 0;
int grid_shape_n = 0;
FastDivmod swizzle = 0;
RasterOrder raster_order = RasterOrder::AlongM;
};

static dim3 get_grid_shape(Arguments const& args, int n_sms) {
return {(uint32_t)n_sms};
}
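  // For reference: n_sms is typically queried on the host via
  //   cudaDeviceGetAttribute(&n_sms, cudaDevAttrMultiProcessorCount, device);
  // the persistent launch then runs one CTA per SM, and each CTA walks
  // multiple output tiles via the Iterator below.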

// Device side kernel params
using Params = Arguments;
static Params to_underlying_arguments(const Arguments& args) { return args; }

class Iterator {
public:
CUTE_DEVICE
Iterator(int start, int step, const Params& params)
: linear_idx_(start), step_(step), params_(params) {}

CUTE_DEVICE
cute::tuple<int, int> operator*() const {
return swizzle_and_rasterize(linear_idx_, params_);
}

CUTE_DEVICE
Iterator& operator++() {
linear_idx_ += step_;
return *this;
}

CUTE_DEVICE
    bool operator!=(const Iterator& e) const {
      // end() carries the total tile count; compare with '<' so a stride
      // that overshoots the final tile still terminates the loop.
      return linear_idx_ < e.linear_idx_;
}

private:
int linear_idx_;
int step_;
const Params& params_;
};

CUTE_DEVICE StaticPersistentTileScheduler(const Params& params)
: params_(params) {
linear_idx_ = blockIdx.x;
grid_size_ = gridDim.x * gridDim.y * gridDim.z;
}

CUTE_DEVICE
Iterator begin() const { return {linear_idx_, grid_size_, params_}; }

CUTE_DEVICE
Iterator end() const {
const int problem_tiles = params_.grid_shape_m * params_.grid_shape_n;
return {problem_tiles, 0, params_};
}

  // For example: ClusterShape = (2, 1), GridShape = (4, 4), Swizzle = 2,
  // rasterized along M:
//
// Reduce Cluster Reduce Swizzle
// <---- N ---->
// ^ +--+--+--+--+ ^
// | |00|04|08|12| | <---- N ----> <- N ->
// C +--+--+--+--+ | <- S ->
// | |01|05|09|13| | +--+--+--+--+ ^ +--+--+ ^
// v +--+--+--+--+ M ---> |00|02|04|06| | ---> |00|02| |
// |02|06|10|14| | +--+--+--+--+ M +--+--+ M
// +--+--+--+--+ | |01|03|05|07| | |01|03| |
// |03|07|11|15| | +--+--+--+--+ v +--+--+ v
// +--+--+--+--+ v
// |
// v
// ^ +--+--+--+--+
// | |00|02|12|14| <---- N ---->
// C +--+--+--+--+ <- S ->
// | |01|03|13|15| +--+--+--+--+ ^
// v +--+--+--+--+ <--- |00|01|04|05| |
// |04|06|08|10| +--+--+--+--+ M
// +--+--+--+--+ |02|03|06|07| |
// |05|07|09|11| +--+--+--+--+ v
// +--+--+--+--+
//
// Expand Cluster Expand Swizzle
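  //
  // Worked trace of the example above (ClusterShape (2, 1), GridShape (4, 4),
  // Swizzle = 2, along M), following swizzle_and_rasterize() below:
  //   linear_idx = 2  -> cluster_idx = 1, cluster offset (m, n) = (0, 0),
  //                      swizzle_idx = 0, swizzle_offset = 1,
  //                      major = 0, panel = 0 (even), minor = 1
  //                      => tile (m, n) = (0, 1)
  //   linear_idx = 12 -> cluster_idx = 6, cluster offset (m, n) = (0, 0),
  //                      swizzle_idx = 3, swizzle_offset = 0,
  //                      major = 1, panel = 1 (odd, so major reversed to 0),
  //                      minor = 2
  //                      => tile (m, n) = (0, 2)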

// compute tile coord from linear idx
CUTE_HOST_DEVICE static cute::tuple<int, int> swizzle_and_rasterize(
int linear_idx,
const Params& params) {
    // number of CTAs per cluster
    const int cluster_size = params.cluster_shape_m * params.cluster_shape_n;
    // number of clusters along the major (rasterization) dimension
const int major_clusters =
params.raster_order == RasterOrder::AlongM
? params.grid_shape_m / params.cluster_shape_m
: params.grid_shape_n / params.cluster_shape_n;

// Convert linear CTA idx/coord into cluster coord.
// Layout: (cluster_size, clusters):(1, cluster_size)
const int cluster_idx = linear_idx / cluster_size;
const int cluster_offset = linear_idx % cluster_size;

// Convert linear idx within cluster into cta coord
// Layout: (cluster_shape_m, cluster_shape_n):(1, cluster_shape_m)
int cluster_offset_m, cluster_offset_n;
params.cluster_shape_m.divmod(
cluster_offset, cluster_offset_n, cluster_offset_m);

int major_idx, minor_idx, panel_idx;
if (params.swizzle > 1) {
// Convert cluster linear idx into swizzled coord
// Layout: (swizzle, panels): (1, swizzle)
int swizzle_idx, swizzle_offset;
params.swizzle.divmod(cluster_idx, swizzle_idx, swizzle_offset);

major_idx = swizzle_idx % major_clusters;
panel_idx = swizzle_idx / major_clusters;
minor_idx = panel_idx * params.swizzle + swizzle_offset;
} else {
// no swizzle, panel size = 1
major_idx = cluster_idx % major_clusters;
panel_idx = cluster_idx / major_clusters;
minor_idx = panel_idx;
}

if ((panel_idx & 1) != 0) {
// odd idx within panel, reverse major index
major_idx = (major_clusters - 1 - major_idx);
}

if (params.raster_order == RasterOrder::AlongM) {
// Map the swizzled cluster tile back to a CTA tile
major_idx = major_idx * params.cluster_shape_m + cluster_offset_m;
minor_idx = minor_idx * params.cluster_shape_n + cluster_offset_n;
return {major_idx, minor_idx};
}

// raster_order == AlongN
// Map the swizzled cluster tile back to a CTA tile
minor_idx = minor_idx * params.cluster_shape_m + cluster_offset_m;
major_idx = major_idx * params.cluster_shape_n + cluster_offset_n;
return {minor_idx, major_idx};
}

private:
int linear_idx_ = 0;
int grid_size_ = 0;
Params params_;
};

} // namespace llm
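As a usage illustration, a persistent kernel would construct the scheduler from the device-side Params and loop over the tiles it yields. The sketch below assumes a hypothetical kernel name and elides the per-tile GEMM work; only the Params type and the begin()/end() iteration pattern come from the class above.

#include "tile_scheduler.cuh"

__global__ void persistent_gemm_kernel(
    llm::StaticPersistentTileScheduler::Params params) {
  llm::StaticPersistentTileScheduler scheduler(params);
  // Each CTA starts at blockIdx.x and strides by the total grid size, so the
  // one-CTA-per-SM launch from get_grid_shape() covers every tile exactly once.
  for (auto it = scheduler.begin(); it != scheduler.end(); ++it) {
    const auto coord = *it;
    const int tile_m = cute::get<0>(coord);
    const int tile_n = cute::get<1>(coord);
    // ... GEMM mainloop / epilogue for output tile (tile_m, tile_n) ...
  }
}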
71 changes: 71 additions & 0 deletions src/kernels/gemm/tile_scheduler_test.cu
@@ -0,0 +1,71 @@
#include <gtest/gtest.h>

#include <cute/tensor.hpp>

#include "tile_scheduler.cuh"

namespace llm {

class TileSchedulerTest
: public ::testing::TestWithParam<std::tuple<int32_t /*cluster_m*/,
int32_t /*cluster_n*/,
int32_t /*grid_m*/,
int32_t /*grid_n*/,
int32_t /*swizzle*/,
RasterOrder /*order*/>> {};

// StaticPersistentTileScheduler
TEST_P(TileSchedulerTest, StaticPersistent) {
using TileScheduler = StaticPersistentTileScheduler;
using namespace cute;

const auto [cluster_m, cluster_n, grid_m, grid_n, swizzle, order] =
GetParam();

TileScheduler::Params params{
cluster_m, cluster_n, grid_m, grid_n, swizzle, order};

const int problem_tiles = params.grid_shape_m * params.grid_shape_n;
// std::vector<int> mapping_data(problem_tiles);
// auto mapping =
// make_tensor(mapping_data.data(),
// make_shape(params.grid_shape_m, params.grid_shape_n));
int pre_tile_m = 0, pre_tile_n = 0;
  // Locality check: with the serpentine swizzle, the L1 distance between
  // consecutively issued tiles is bounded by one panel width plus one
  // cluster extent.
  const int max_dist = order == RasterOrder::AlongM
                           ? (swizzle * cluster_n) + cluster_m
                           : (swizzle * cluster_m) + cluster_n;
int32_t valid = 0;
for (int linear_idx = 0; linear_idx < problem_tiles; ++linear_idx) {
const auto [tile_m, tile_n] =
TileScheduler::swizzle_and_rasterize(linear_idx, params);

const int dist =
std::abs(tile_m - pre_tile_m) + std::abs(tile_n - pre_tile_n);
pre_tile_m = tile_m;
pre_tile_n = tile_n;
EXPECT_LE(dist, max_dist);
// mapping(tile_m, tile_n) = linear_idx;

    // Flatten (tile_m, tile_n) with layout (grid_m, grid_n):(1, grid_m) and
    // XOR it against the linear index; a zero accumulator at the end is a
    // cheap (necessary, not sufficient) check that each tile is produced
    // exactly once.
const int idx = tile_m + (tile_n * grid_m);
valid ^= idx;
valid ^= linear_idx;
}
EXPECT_EQ(valid, 0);

// print_tensor(mapping);
}

INSTANTIATE_TEST_SUITE_P(
TileScheduler,
TileSchedulerTest,
::testing::Combine(::testing::Values(1, 2), // cluster_m
::testing::Values(1, 2), // cluster_n
::testing::Values(8, 16), // grid_m
::testing::Values(8, 16), // grid_n
::testing::Values(1, 2, 4), // swizzle
::testing::Values(RasterOrder::AlongM,
RasterOrder::AlongN) // order
));

} // namespace llm