Skip to content

Commit 8f30615

Browse files
authored
kernel: add grouped gemm support for moe (#458)
1 parent e38b313 commit 8f30615

File tree

6 files changed

+822
-9
lines changed

6 files changed

+822
-9
lines changed

src/kernels/CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
include(cc_library)
22

33
cc_library(
4-
NAME
4+
NAME
55
kernels
6-
HDRS
6+
HDRS
77
reduce_kernel_utils.cuh
88
activation_kernels.h
99
layernorm_kernels.h
1010
pos_embedding_kernels.h
1111
kv_cache_kernels.h
1212
sampling/sampling_kernels.h
13-
SRCS
13+
SRCS
1414
activation_kernels.cu
1515
layernorm_kernels.cu
1616
pos_embedding_kernels.cu
@@ -28,7 +28,7 @@ cc_library(
2828

2929
add_subdirectory(attention)
3030
add_subdirectory(moe)
31+
add_subdirectory(gemm)
3132
add_subdirectory(quantization)
3233
add_subdirectory(playground)
33-
add_subdirectory(triton)
34-
34+
# add_subdirectory(triton)

src/kernels/gemm/CMakeLists.txt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
include(cc_library)
2+
include(cc_test)
3+
4+
cc_library(
5+
NAME
6+
gemm.kernels
7+
HDRS
8+
grouped_gemm_kernel_sm80.cuh
9+
DEPS
10+
cutlass
11+
)
12+
13+
14+
cc_test(
15+
NAME
16+
gemm_kernel_test
17+
SRCS
18+
grouped_gemm_kernel_sm80_test.cu
19+
DEPS
20+
:gemm.kernels
21+
absl::random_random
22+
GTest::gtest_main
23+
torch
24+
)

src/kernels/gemm/gather_tensor.hpp

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// adapted from
2+
// https://github.com/NVIDIA/cutlass/blob/main/examples/common/gather_tensor.hpp
3+
#pragma once
4+
5+
#include "cute/layout.hpp"
6+
#include "cute/layout_composed.hpp"
7+
#include "cute/tensor.hpp"
8+
namespace llm {
9+
10+
using namespace cute;
11+
12+
namespace detail {
13+
14+
// every stride must be divisible by div
15+
template <class Stride, class Div>
16+
CUTE_HOST_DEVICE constexpr auto safe_stride_div(Stride const& s,
17+
const Div& div) {
18+
if constexpr (is_tuple<Stride>::value) {
19+
return transform(s, [&](auto const& a) { return safe_stride_div(a, div); });
20+
} else {
21+
return safe_div(s, div);
22+
}
23+
CUTE_GCC_UNREACHABLE;
24+
}
25+
26+
} // namespace detail
27+
28+
/// Custom stride object that applies a function followed by a stride
29+
template <class Func, class Stride>
30+
struct CustomStride {
31+
CUTE_HOST_DEVICE constexpr CustomStride(const Func& func,
32+
const Stride& stride)
33+
: func_(func), stride_(stride) {}
34+
35+
template <class I>
36+
CUTE_HOST_DEVICE constexpr friend auto operator*(I i, const CustomStride& s) {
37+
return inner_product(s.func_(i), s.stride_);
38+
}
39+
40+
template <class I>
41+
CUTE_HOST_DEVICE constexpr friend auto operator*(const CustomStride& s, I i) {
42+
return inner_product(s.func_(i), s.stride_);
43+
}
44+
45+
template <class Div>
46+
CUTE_HOST_DEVICE constexpr friend auto safe_div(const CustomStride& s,
47+
const Div& div) {
48+
auto stride = detail::safe_stride_div(s.stride_, div);
49+
return CustomStride<Func, decltype(stride)>(s.func_, stride);
50+
}
51+
52+
template <class Shape>
53+
CUTE_HOST_DEVICE constexpr friend auto make_layout(
54+
const Shape& shape,
55+
const CustomStride& stride) {
56+
return Layout<Shape, CustomStride>(shape, stride);
57+
}
58+
59+
CUTE_HOST_DEVICE friend void print(CustomStride const& s) {
60+
print("CustomStride{func,");
61+
print(s.stride_);
62+
print("}");
63+
}
64+
65+
Func func_;
66+
Stride stride_;
67+
};
68+
69+
template <class Func, class Shape, class Stride>
70+
CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Func&& func,
71+
const Shape& shape,
72+
const Stride& stride) {
73+
// Use a dummy shape and replace the first non-unit stride with a custom
74+
// gather stride
75+
auto idx =
76+
find_if(stride, [](auto x) { return not is_constant<1, decltype(x)>{}; });
77+
constexpr int I = decltype(idx)::value;
78+
return make_layout(
79+
repeat_like(shape, _1{}),
80+
replace<I>(stride,
81+
CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
82+
}
83+
84+
/// Helper function to optionally create a gather tensor
85+
template <class Iterator, class Shape, class Stride, class Func>
86+
CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter,
87+
const Shape& shape,
88+
const Stride& stride,
89+
Func&& func) {
90+
Layout matrix_layout = make_identity_layout(shape);
91+
auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
92+
Layout gather_layout =
93+
make_custom_stride_layout(static_cast<Func&&>(func), shape, stride);
94+
return make_tensor(iter,
95+
ComposedLayout{gather_layout, offset, matrix_layout});
96+
}
97+
98+
} // namespace llm
99+
100+
namespace cute {
101+
102+
template <int N, int I, class Shape, class Stride>
103+
CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape,
104+
Stride const& stride) {
105+
if constexpr (is_tuple<Shape>::value) {
106+
return transform_layout(shape, stride, [](auto const& s, auto const& d) {
107+
return upcast<N, I>(s, d);
108+
});
109+
} else if constexpr (is_scaled_basis<Stride>::value) {
110+
if constexpr (Stride::mode() == I) {
111+
return make_layout(ceil_div(shape, Int<N>{}), ceil_div(stride, Int<N>{}));
112+
} else {
113+
return make_layout(shape, stride);
114+
}
115+
} else {
116+
return upcast<N>(shape, stride);
117+
}
118+
119+
CUTE_GCC_UNREACHABLE;
120+
}
121+
122+
template <int N,
123+
class OuterShape,
124+
class OuterStride,
125+
class Offset,
126+
class Shape,
127+
class Stride>
128+
CUTE_HOST_DEVICE constexpr auto upcast(
129+
ComposedLayout<Layout<OuterShape, OuterStride>,
130+
Offset,
131+
Layout<Shape, Stride>> const& layout) {
132+
// Find index of the stride-1 mode - that is the only one that requires
133+
// updating inner shape and offset
134+
auto idx = find_if(layout.layout_a().stride(),
135+
[](auto x) { return is_constant<1, decltype(x)>{}; });
136+
constexpr int I = decltype(idx)::value;
137+
138+
// Upcast the outer layout (works as expected)
139+
auto outer = upcast<N>(layout.layout_a());
140+
141+
// Upcast the accumulated offset along stride-1 mode
142+
auto offset = as_arithmetic_tuple(
143+
replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));
144+
145+
// Upcast the inner layout's shape along stride-1 mode
146+
auto inner =
147+
upcast<N, I>(layout.layout_b().shape(), layout.layout_b().stride());
148+
149+
return composition(outer, offset, inner);
150+
}
151+
152+
template <class ShapeA,
153+
class StrideA,
154+
class OuterShapeB,
155+
class OuterStrideB,
156+
class OffsetB,
157+
class ShapeB,
158+
class StrideB>
159+
CUTE_HOST_DEVICE constexpr auto max_common_vector(
160+
Layout<ShapeA, StrideA> const& a,
161+
ComposedLayout<Layout<OuterShapeB, OuterStrideB>,
162+
OffsetB,
163+
Layout<ShapeB, StrideB>> const& b) {
164+
return max_common_vector(b.layout_b(), a);
165+
}
166+
167+
} // namespace cute

0 commit comments

Comments
 (0)