Skip to content

Commit ef50d7b

Browse files
authored
feat: add tma copy for paged kv (#480)
1 parent c1edc65 commit ef50d7b

File tree

4 files changed

+407
-7
lines changed

4 files changed

+407
-7
lines changed

src/common/gtest_main.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace {
88
// https://github.com/NVIDIA/cutlass/blob/main/test/unit/common/filter_architecture.cpp#L78
99

1010
// generate gtest filter string based on device compute capability
11-
std::string gen_gtest_filter() {
11+
std::string gen_gtest_filter(const std::string& cmd_filter) {
1212
static const int kMaxComputeCapability = 10000;
1313
// Text filters for each kernel based on supported compute capability
1414
struct Filter {
@@ -60,16 +60,19 @@ std::string gen_gtest_filter() {
6060
// add separator if not the first filter
6161
ss << (i++ ? ":" : "") << filter.filter;
6262
}
63-
return ss.str();
63+
if (cmd_filter.empty()) {
64+
// If no cmd filter, return the negative filters
65+
return ss.str();
66+
}
67+
// If cmd filter is present, append the negative filters
68+
// to the existing filter
69+
return cmd_filter + ":" + ss.str();
6470
}
6571
} // namespace
6672

6773
int main(int argc, char** argv) {
6874
::testing::InitGoogleTest(&argc, argv);
69-
// honor --gtest_filter from commandline
70-
if (::testing::GTEST_FLAG(filter).empty() ||
71-
::testing::GTEST_FLAG(filter) == "*") {
72-
::testing::GTEST_FLAG(filter) = gen_gtest_filter();
73-
}
75+
const auto cmd_filter = ::testing::GTEST_FLAG(filter);
76+
::testing::GTEST_FLAG(filter) = gen_gtest_filter(cmd_filter);
7477
return RUN_ALL_TESTS();
7578
}

src/kernels/attention/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,15 @@ nvbench_binary(
120120
:attention.template
121121
)
122122

123+
cc_test(
124+
NAME
125+
tma_test
126+
SRCS
127+
sm120_tma_pagedkv_test.cu
128+
DEPS
129+
:gtest_main
130+
:cutlass
131+
torch
132+
)
133+
123134
add_subdirectory(tools)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#pragma once
2+
3+
// #include <cute/algorithm/tuple_algorithms.hpp>
4+
#include <cute/config.hpp>
5+
#include <cute/int_tuple.hpp>
6+
#include <cute/numeric/arithmetic_tuple.hpp>
7+
8+
namespace cute {
9+
10+
//
11+
// An arithmetic tuple iterator with a coordinate transform.
12+
//
13+
template <class ArithTuple, class Transform>
14+
struct GatherArithmeticTupleIterator {
15+
using value_type = ArithTuple;
16+
using element_type = ArithTuple;
17+
using reference = ArithTuple;
18+
19+
ArithTuple coord_;
20+
Transform transform_;
21+
22+
CUTE_HOST_DEVICE constexpr GatherArithmeticTupleIterator(
23+
const ArithTuple& coord,
24+
const Transform& transform)
25+
: coord_(coord), transform_(transform) {}
26+
27+
CUTE_HOST_DEVICE constexpr auto coord() const {
28+
// apply the transform to the coordinate
29+
return transform_(coord_);
30+
}
31+
32+
CUTE_HOST_DEVICE constexpr auto operator*() const { return coord(); }
33+
34+
template <class Coord>
35+
CUTE_HOST_DEVICE constexpr auto operator+(const Coord& c) const {
36+
auto coord = coord_ + c;
37+
return GatherArithmeticTupleIterator<remove_cvref_t<decltype(coord)>,
38+
Transform>(coord, transform_);
39+
}
40+
};
41+
42+
template <class Tuple, class Transform>
43+
CUTE_HOST_DEVICE constexpr auto make_gather_inttuple_iter(
44+
const Tuple& t,
45+
const Transform& transform) {
46+
return GatherArithmeticTupleIterator(as_arithmetic_tuple(t), transform);
47+
}
48+
49+
// Generate the TMA coord tensor with transform
50+
template <class TMA, class GShape, class Transform>
51+
CUTE_HOST_DEVICE constexpr auto make_gather_tma_tensor(
52+
const TMA& tma,
53+
const GShape& g_shape,
54+
const Transform& transform) {
55+
static_assert(is_congruent<decltype(g_shape),
56+
decltype(tma.aux_params_.g_stride_)>::value);
57+
auto layout = make_layout(g_shape, tma.aux_params_.g_stride_);
58+
return make_tensor(make_gather_inttuple_iter(coprofile(layout), transform),
59+
layout);
60+
}
61+
62+
//
63+
// Display utilities
64+
//
65+
66+
template <class ArithTuple, class Transform>
67+
CUTE_HOST_DEVICE void print(
68+
const GatherArithmeticTupleIterator<ArithTuple, Transform>& iter) {
69+
printf("GatherArithTuple");
70+
print(iter.coord());
71+
}
72+
73+
#if !defined(__CUDACC_RTC__)
74+
template <class ArithTuple, class Transform>
75+
CUTE_HOST std::ostream& operator<<(
76+
std::ostream& os,
77+
const GatherArithmeticTupleIterator<ArithTuple, Transform>& iter) {
78+
return os << "GatherArithTuple" << iter.coord();
79+
}
80+
#endif
81+
82+
} // end namespace cute

0 commit comments

Comments
 (0)