Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/common/gtest_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace {
// https://github.com/NVIDIA/cutlass/blob/main/test/unit/common/filter_architecture.cpp#L78

// generate gtest filter string based on device compute capability
std::string gen_gtest_filter() {
std::string gen_gtest_filter(const std::string& cmd_filter) {
static const int kMaxComputeCapability = 10000;
// Text filters for each kernel based on supported compute capability
struct Filter {
Expand Down Expand Up @@ -60,16 +60,19 @@ std::string gen_gtest_filter() {
// add separator if not the first filter
ss << (i++ ? ":" : "") << filter.filter;
}
return ss.str();
if (cmd_filter.empty()) {
// If no cmd filter, return the negative filters
return ss.str();
}
// If cmd filter is present, append the negative filters
// to the existing filter
return cmd_filter + ":" + ss.str();
}
} // namespace

int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
// honor --gtest_filter from commandline
if (::testing::GTEST_FLAG(filter).empty() ||
::testing::GTEST_FLAG(filter) == "*") {
::testing::GTEST_FLAG(filter) = gen_gtest_filter();
}
const auto cmd_filter = ::testing::GTEST_FLAG(filter);
::testing::GTEST_FLAG(filter) = gen_gtest_filter(cmd_filter);
return RUN_ALL_TESTS();
}
11 changes: 11 additions & 0 deletions src/kernels/attention/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,15 @@ nvbench_binary(
:attention.template
)

# Declares the `tma_test` test target: its only source is
# sm120_tma_pagedkv_test.cu and it lists :gtest_main, :cutlass and torch
# as dependencies.
cc_test(
NAME
tma_test
SRCS
sm120_tma_pagedkv_test.cu
DEPS
:gtest_main
:cutlass
torch
)

add_subdirectory(tools)
82 changes: 82 additions & 0 deletions src/kernels/attention/gather_tma_tensor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#pragma once

// #include <cute/algorithm/tuple_algorithms.hpp>
#include <cute/config.hpp>
#include <cute/int_tuple.hpp>
#include <cute/numeric/arithmetic_tuple.hpp>

namespace cute {

//
// An arithmetic tuple iterator with a coordinate transform.
//
// Iterator over an arithmetic-tuple coordinate that is paired with a
// coordinate transform.
//
// The coordinate is stored untransformed; the transform is applied only when
// the iterator is dereferenced (or `coord()` is called). Offsetting the
// iterator with `operator+` therefore advances the raw coordinate and carries
// the same transform along.
//
// ArithTuple: type of the stored (untransformed) coordinate.
// Transform:  callable invoked as `transform(coord)`; its result is what
//             dereferencing yields.
template <class ArithTuple, class Transform>
struct GatherArithmeticTupleIterator {
  using value_type = ArithTuple;
  using element_type = ArithTuple;
  using reference = ArithTuple;

  // Raw coordinate, before the transform is applied.
  ArithTuple coord_;
  // Callable mapping the raw coordinate to the coordinate handed to users.
  Transform transform_;

  // Stores copies of both the coordinate and the transform.
  CUTE_HOST_DEVICE constexpr GatherArithmeticTupleIterator(
      const ArithTuple& coord,
      const Transform& transform)
      : coord_(coord), transform_(transform) {}

  // Returns the stored coordinate after passing it through the transform.
  CUTE_HOST_DEVICE constexpr auto coord() const {
    return transform_(coord_);
  }

  // Dereferencing is the same as asking for the transformed coordinate.
  CUTE_HOST_DEVICE constexpr auto operator*() const {
    return transform_(coord_);
  }

  // Returns a new iterator whose raw coordinate is `coord_ + c`; the
  // transform is copied over unchanged. The coordinate type of the result is
  // whatever type the sum has, so it may differ from ArithTuple.
  template <class Coord>
  CUTE_HOST_DEVICE constexpr auto operator+(const Coord& c) const {
    auto shifted = coord_ + c;
    using ShiftedTuple = remove_cvref_t<decltype(shifted)>;
    return GatherArithmeticTupleIterator<ShiftedTuple, Transform>(shifted,
                                                                  transform_);
  }
};

// Builds a GatherArithmeticTupleIterator from a tuple and a transform.
//
// `t` is first converted with `as_arithmetic_tuple`; the converted value
// becomes the iterator's raw coordinate and `transform` is stored alongside
// it.
template <class Tuple, class Transform>
CUTE_HOST_DEVICE constexpr auto make_gather_inttuple_iter(
    const Tuple& t,
    const Transform& transform) {
  auto arith_coord = as_arithmetic_tuple(t);
  return GatherArithmeticTupleIterator<decltype(arith_coord), Transform>(
      arith_coord, transform);
}

// Generates the TMA coordinate tensor with a coordinate transform attached.
//
// A layout is formed from `g_shape` and the stride held in
// `tma.aux_params_.g_stride_`. The returned tensor uses that layout, and its
// iterator is a gather iterator that starts at `coprofile(layout)` and applies
// `transform` whenever it is dereferenced.
//
// Compilation fails unless `g_shape` is congruent with the stride.
template <class TMA, class GShape, class Transform>
CUTE_HOST_DEVICE constexpr auto make_gather_tma_tensor(
    const TMA& tma,
    const GShape& g_shape,
    const Transform& transform) {
  using GStride = decltype(tma.aux_params_.g_stride_);
  static_assert(is_congruent<decltype(g_shape), GStride>::value);

  auto layout = make_layout(g_shape, tma.aux_params_.g_stride_);
  auto gather_iter = make_gather_inttuple_iter(coprofile(layout), transform);
  return make_tensor(gather_iter, layout);
}

//
// Display utilities
//

// Prints the tag "GatherArithTuple" followed by the iterator's transformed
// coordinate (the value obtained by dereferencing it).
template <class ArithTuple, class Transform>
CUTE_HOST_DEVICE void print(
    const GatherArithmeticTupleIterator<ArithTuple, Transform>& iter) {
  printf("GatherArithTuple");
  print(*iter);
}

#if !defined(__CUDACC_RTC__)
// Stream insertion: writes the tag "GatherArithTuple" followed by the
// iterator's transformed coordinate. Compiled out under __CUDACC_RTC__.
template <class ArithTuple, class Transform>
CUTE_HOST std::ostream& operator<<(
    std::ostream& os,
    const GatherArithmeticTupleIterator<ArithTuple, Transform>& iter) {
  os << "GatherArithTuple";
  os << iter.coord();
  return os;
}
#endif

} // end namespace cute
Loading