diff --git a/src/common/gtest_main.cpp b/src/common/gtest_main.cpp
index c9b9bf2d..3dcbea86 100644
--- a/src/common/gtest_main.cpp
+++ b/src/common/gtest_main.cpp
@@ -8,7 +8,7 @@ namespace {
 
 // https://github.com/NVIDIA/cutlass/blob/main/test/unit/common/filter_architecture.cpp#L78
 // generate gtest filter string based on device compute capability
-std::string gen_gtest_filter() {
+std::string gen_gtest_filter(const std::string& cmd_filter) {
   static const int kMaxComputeCapability = 10000;
   // Text filters for each kernel based on supported compute capability
   struct Filter {
@@ -60,16 +60,19 @@ std::string gen_gtest_filter() {
     // add separator if not the first filter
     ss << (i++ ? ":" : "") << filter.filter;
   }
-  return ss.str();
+  if (cmd_filter.empty()) {
+    // If no cmd filter, return the negative filters
+    return ss.str();
+  }
+  // If cmd filter is present, append the negative filters
+  // to the existing filter
+  return cmd_filter + ":" + ss.str();
 }
 }  // namespace
 
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
-  // honor --gtest_filter from commandline
-  if (::testing::GTEST_FLAG(filter).empty() ||
-      ::testing::GTEST_FLAG(filter) == "*") {
-    ::testing::GTEST_FLAG(filter) = gen_gtest_filter();
-  }
+  const auto cmd_filter = ::testing::GTEST_FLAG(filter);
+  ::testing::GTEST_FLAG(filter) = gen_gtest_filter(cmd_filter);
   return RUN_ALL_TESTS();
 }
diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt
index 04f1df9b..c2a9d491 100644
--- a/src/kernels/attention/CMakeLists.txt
+++ b/src/kernels/attention/CMakeLists.txt
@@ -120,4 +120,15 @@ nvbench_binary(
     :attention.template
 )
 
+cc_test(
+  NAME
+    tma_test
+  SRCS
+    sm120_tma_pagedkv_test.cu
+  DEPS
+    :gtest_main
+    :cutlass
+    torch
+)
+
 add_subdirectory(tools)
diff --git a/src/kernels/attention/gather_tma_tensor.hpp b/src/kernels/attention/gather_tma_tensor.hpp
new file mode 100644
index 00000000..3dc29ccc
--- /dev/null
+++ b/src/kernels/attention/gather_tma_tensor.hpp
@@ -0,0 +1,82 @@
+#pragma once
+
+// #include
+#include
+#include
+#include
+
+namespace cute {
+
+//
+// An arithmetic tuple iterator with a coordinate transform.
+//
+template <class ArithTuple, class Transform>
+struct GatherArithmeticTupleIterator {
+  using value_type = ArithTuple;
+  using element_type = ArithTuple;
+  using reference = ArithTuple;
+
+  ArithTuple coord_;
+  Transform transform_;
+
+  CUTE_HOST_DEVICE constexpr GatherArithmeticTupleIterator(
+      const ArithTuple& coord,
+      const Transform& transform)
+      : coord_(coord), transform_(transform) {}
+
+  CUTE_HOST_DEVICE constexpr auto coord() const {
+    // apply the transform to the coordinate
+    return transform_(coord_);
+  }
+
+  CUTE_HOST_DEVICE constexpr auto operator*() const { return coord(); }
+
+  template <class Coord>
+  CUTE_HOST_DEVICE constexpr auto operator+(const Coord& c) const {
+    auto coord = coord_ + c;
+    return GatherArithmeticTupleIterator<remove_cvref_t<decltype(coord)>,
+                                         Transform>(coord, transform_);
+  }
+};
+
+template <class Tuple, class Transform>
+CUTE_HOST_DEVICE constexpr auto make_gather_inttuple_iter(
+    const Tuple& t,
+    const Transform& transform) {
+  return GatherArithmeticTupleIterator(as_arithmetic_tuple(t), transform);
+}
+
+// Generate the TMA coord tensor with transform
+template <class TMA, class GShape, class Transform>
+CUTE_HOST_DEVICE constexpr auto make_gather_tma_tensor(
+    const TMA& tma,
+    const GShape& g_shape,
+    const Transform& transform) {
+  static_assert(is_congruent<GShape, decltype(tma.aux_params_.g_stride_)>::value);
+  auto layout = make_layout(g_shape, tma.aux_params_.g_stride_);
+  return make_tensor(make_gather_inttuple_iter(coprofile(layout), transform),
+                     layout);
+}
+
+//
+// Display utilities
+//
+
+template <class ArithTuple, class Transform>
+CUTE_HOST_DEVICE void print(
+    const GatherArithmeticTupleIterator<ArithTuple, Transform>& iter) {
+  printf("GatherArithTuple");
+  print(iter.coord());
+}
+
+#if !defined(__CUDACC_RTC__)
+template <class ArithTuple, class Transform>
+CUTE_HOST std::ostream& operator<<(
+    std::ostream& os,
+    const GatherArithmeticTupleIterator<ArithTuple, Transform>& iter) {
+  return os << "GatherArithTuple" << iter.coord();
+}
+#endif
+
+}  // end namespace cute
diff --git a/src/kernels/attention/sm120_tma_pagedkv_test.cu b/src/kernels/attention/sm120_tma_pagedkv_test.cu
new file mode 100644
index 00000000..a5c51486
--- /dev/null
+++ b/src/kernels/attention/sm120_tma_pagedkv_test.cu
@@ -0,0 +1,304 @@
+#include
+#include
+#include
+
+#include
+#include
+
+#include "cute/swizzle_layout.hpp"
+#include "gather_tma_tensor.hpp"
+
+namespace llm {
+using namespace cute;
+
+template <class T, class SmemLayout>
+struct SharedStorage {
+  cute::ArrayEngine<T, cute::cosize_v<SmemLayout>> smem;
+  alignas(16) cute::uint64_t tma_load_mbar[1];
+};
+
+template <class T,
+          class TiledCopy,
+          class CTA_Tiler,
+          class GmemLayout,
+          class SmemLayout>
+__global__ void tma_test_device_cute(T* g_out,
+                                     int const* block_ids,
+                                     int block_size,
+                                     CUTE_GRID_CONSTANT TiledCopy const tma,
+                                     CTA_Tiler cta_tiler,
+                                     GmemLayout gmem_layout,
+                                     SmemLayout smem_layout) {
+  using namespace cute;
+
+  CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) ==
+                       product_each(shape(smem_layout)));
+
+  // Use Shared Storage structure to allocate and distribute aligned SMEM
+  // addresses
+  extern __shared__ char shared_memory[];
+  using SharedStorage = SharedStorage<T, SmemLayout>;
+  SharedStorage& shared_storage =
+      *reinterpret_cast<SharedStorage*>(shared_memory);
+
+  // Construct SMEM tensor
+  // (CTA_TILE_M,CTA_TILE_N,...)
+  Tensor sA =
+      make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout);
+  Tensor sA_no_swizzle = make_tensor(make_smem_ptr(shared_storage.smem.begin()),
+                                     get_nonswizzle_portion(smem_layout));
+
+  // Shared memory barriers use 64bits in SMEM for synchronization
+  uint64_t* tma_load_mbar = shared_storage.tma_load_mbar;
+
+  // TMA requires special handling of strides to deal with coord codomain
+  // mapping. Represent the full tensors -- get these from TMA.
+  // Tensor mA = tma.get_tma_tensor(shape(gmem_layout));
+
+  auto coord_transform = [block_ids, block_size](auto coord) {
+    constexpr int I = 1;
+    const int idx = get<I>(coord);
+    const int blk_idx = idx / block_size;
+    const int blk_offset = idx % block_size;
+    const int g_idx = (block_ids[blk_idx] * block_size) + blk_offset;
+    // print("mapping: %d => %d\n", idx, g_idx);
+    // return replace<I>(coord, g_idx);
+    return replace<I>(coord, g_idx);
+  };
+
+  // (m, n) => (1@0, 1@1)
+  Tensor mA = make_gather_tma_tensor(tma, shape(gmem_layout), coord_transform);
+  Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout);
+
+  constexpr int R = rank_v<CTA_Tiler>;
+  // (CTA_TILE_M,CTA_TILE_N, REST_M,REST_N,...)
+  Tensor gA = flat_divide(mA, cta_tiler);
+  // (CTA_TILE_M,CTA_TILE_N, REST_M,REST_N,...)
+  Tensor gB = flat_divide(mB, cta_tiler);
+
+  //
+  // Prepare the TMA_LOAD
+  //
+
+  auto cta_tma = tma.get_slice(Int<0>{});  // CTA slice
+  // (TMA,TMA_M,TMA_N,REST_M,REST_N)
+  Tensor tAgA_x = cta_tma.partition_S(gA);
+  // (TMA,TMA_M,TMA_N)
+  Tensor tAsA_x = cta_tma.partition_D(sA);
+
+  //
+  // Perform the TMA_LOAD
+  //
+
+  // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through
+  // the tiles
+  // (TMA,REST)
+  Tensor tAgA = group_modes<1, rank(tAgA_x)>(tAgA_x);
+  Tensor tAsA = group_modes<1, rank(tAsA_x)>(tAsA_x);
+  static_assert(size<1>(tAsA) == 1);
+
+  // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output
+  // (CTA_TILE, REST)
+  Tensor tBgB = group_modes<0, R>(group_modes<R, rank(gB)>(gB));
+
+  // Loop over the TMA stages, using smem as our buffer
+  // (TMA,REST)
+  for (int stage = 0; stage < size<1>(tAgA); ++stage) {
+    // Set the bytes transferred in this TMA transaction (may involve multiple
+    // issues)
+    constexpr int kTmaTransactionBytes =
+        sizeof(make_tensor_like(tensor<0>(tAsA)));
+
+    if (threadIdx.x == 0) {
+      // print("\n ########### %d ########### \n", stage);
+      // print("sA: ");
+      // print(sA);
+      // print("\n");
+
+      /// Initialize shared memory barrier
+      tma_load_mbar[0] = 0;
+      cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/);
+      cute::set_barrier_transaction_bytes(tma_load_mbar[0],
+                                          kTmaTransactionBytes);
+
+      copy(tma.with(tma_load_mbar[0]), tAgA(_, stage), tAsA(_, 0));
+    }
+    __syncthreads();
+
+    /// Wait on the shared memory barrier until the phase bit flips from
+    /// kPhaseBit value
+    constexpr int kPhaseBit = 0;
+    cute::wait_barrier(tma_load_mbar[0], kPhaseBit);
+
+    // Subbyte elements could cause race conditions, so be even more
+    // conservative
+    if (thread0()) {
+      copy(sA, tBgB(_, stage));
+    }
+
+    __syncthreads();
+  }
+}
+
+template <class T,
+          class TmaType = T,
+          class CopyOp,
+          class GMEM_Layout,
+          class G_GMEM_Layout,
+          class SMEM_Layout,
+          class CTA_Tile>
+auto test_tma_block_load(CopyOp const& copy_op,
+                         GMEM_Layout const& gmem_layout,
+                         G_GMEM_Layout const& gather_gmem_layout,
+                         SMEM_Layout const& smem_layout,
+                         CTA_Tile const& cta_tile,
+                         int32_t block_size) {
+  assert(block_size % 8 == 0);
+  const int m_gather = size<0>(gather_gmem_layout);
+  assert(m_gather % block_size == 0);
+
+  const int32_t n_blocks = m_gather / block_size;
+  const int32_t n_slots = n_blocks * block_size;
+  // generate blocks
+  std::vector<int32_t> block_ids;
+  std::vector<int32_t> slot_ids;
+  block_ids.reserve(n_blocks);
+  slot_ids.reserve(n_slots);
+  for (int i = 0; i < n_blocks; ++i) {
+    const int blk_id = i ^ 1;
+    block_ids.push_back(blk_id);
+    const int32_t slot_base = blk_id * block_size;
+    for (int32_t j = 0; j < block_size; ++j) {
+      slot_ids.push_back(slot_base + j);
+    }
+  }
+
+  // Allocate and initialize host test data
+  size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
+  thrust::host_vector<uint8_t> h_in(N);
+  for (size_t i = 0; i < h_in.size(); ++i) {
+    h_in[i] = uint8_t(i % 13);
+  }
+  Tensor hA_in =
+      make_tensor(make_gmem_ptr<T>(raw_pointer_cast(h_in.data())), gmem_layout);
+
+  // Allocate and initialize device test data
+  size_t GN = ceil_div(cosize(gather_gmem_layout) * sizeof_bits<T>::value, 8);
+  thrust::device_vector<uint8_t> d_in = h_in;
+  thrust::device_vector<uint8_t> d_out(GN, uint8_t(-1));  // overflow uint
+  thrust::device_vector<int32_t> d_block_ids = block_ids;
+
+  // Create TMA for this device Tensor
+  Tensor gA =
+      make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_in.data())), gmem_layout);
+  auto tma =
+      make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{});
+
+  // Launch
+  int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
+  tma_test_device_cute<<<1, 128, smem_size>>>(
+      reinterpret_cast<T*>(raw_pointer_cast(d_out.data())),
+      reinterpret_cast<int const*>(raw_pointer_cast(d_block_ids.data())),
+      block_size,
+      tma,
+      cta_tile,
+      gather_gmem_layout,
+      smem_layout);
+
+  // Copy results back to host
+  thrust::host_vector<uint8_t> h_out = d_out;
+  Tensor hA_out = make_tensor(make_gmem_ptr<T>(raw_pointer_cast(h_out.data())),
+                              gather_gmem_layout);
+
+  thrust::host_vector<uint8_t> h_out_ref(GN, uint8_t(-1));
+  Tensor hA_out_ref = make_tensor(
+      make_gmem_ptr<T>(raw_pointer_cast(h_out_ref.data())), gather_gmem_layout);
+  for (int i = 0; i < slot_ids.size(); ++i) {
+    cute::copy(hA_in(slot_ids[i], _), hA_out_ref(i, _));
+  }
+
+  // Validate the results. Print only the first 3 errors.
+  int count = 3;
+  for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) {
+    EXPECT_EQ(hA_out_ref(i), hA_out(i));
+    if (hA_out_ref(i) != hA_out(i)) {
+      --count;
+    }
+  }
+
+  return tma;
+}
+
+template <class T,
+          class TmaType = T,
+          class GMEM_Layout,
+          class G_GMEM_Layout,
+          class SMEM_Layout,
+          class CTA_Tile>
+auto test_tma_block_load(GMEM_Layout const& gmem_layout,
+                         G_GMEM_Layout const& gather_gmem_layout,
+                         SMEM_Layout const& smem_layout,
+                         CTA_Tile const& cta_tile,
+                         int32_t block_size) {
+  return test_tma_block_load<T, TmaType>(SM90_TMA_LOAD{},
+                                         gmem_layout,
+                                         gather_gmem_layout,
+                                         smem_layout,
+                                         cta_tile,
+                                         block_size);
+}
+
+template <class T,
+          class TmaType = T,
+          class GMEM_Layout,
+          class G_GMEM_Layout,
+          class SMEM_Layout>
+auto test_tma_block_load(GMEM_Layout const& gmem_layout,
+                         G_GMEM_Layout const& gather_gmem_layout,
+                         SMEM_Layout const& smem_layout,
+                         int32_t block_size) {
+  return test_tma_block_load<T, TmaType>(gmem_layout,
+                                         gather_gmem_layout,
+                                         smem_layout,
+                                         product_each(shape(smem_layout)),
+                                         block_size);
+}
+
+template <class T, class TmaType, template <class> typename SWIZZLE_ATOM>
+auto test_tma_block_load_swizzle_tile_k(int32_t block_size) {
+  auto gmem_layout = make_layout(make_shape(256, 256), GenRowMajor{});
+
+  auto gather_shape = Shape<_64, _256>{};
+  auto gather_gmem_layout = make_layout(gather_shape, GenRowMajor{});
+  auto smem_layout =
+      tile_to_shape(SWIZZLE_ATOM<T>{}, gather_shape, Step<_1, _0>{});
+
+  // TODO: fix the test failures related to tma box size
+  // assert (size<1>(SWIZZLE_ATOM<T>{}) != size<1>(smem_layout));
+  return test_tma_block_load<T, TmaType>(
+      gmem_layout, gather_gmem_layout, smem_layout, block_size);
+}
+
+auto test_tma_block_load(int32_t block_size) {
+  test_tma_block_load_swizzle_tile_k(
+      block_size);
+  test_tma_block_load_swizzle_tile_k(
+      block_size);
+  test_tma_block_load_swizzle_tile_k(
+      block_size);
+  test_tma_block_load_swizzle_tile_k(
+      block_size);
+
+  test_tma_block_load_swizzle_tile_k(
+      block_size);
+  test_tma_block_load_swizzle_tile_k(
+      block_size);
+  test_tma_block_load_swizzle_tile_k(
+      block_size);
+  test_tma_block_load_swizzle_tile_k(
+      block_size);
+}
+
+TEST(SM120_Tma, Test_Tma_Block_Load) { test_tma_block_load(/*block_size=*/8); }
+
+}  // namespace llm
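
Note (illustration only, not part of the patch): the core of the change is the coord_transform lambda passed to make_gather_tma_tensor, which remaps the row mode of a TMA coordinate through a block table so that a logically contiguous paged-KV view gathers rows from scattered physical blocks. The host-side sketch below shows the same index arithmetic in isolation; gather_row is a hypothetical helper name, and the block table mirrors the i ^ 1 permutation built in test_tma_block_load.

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the device-side coord_transform lambda:
// a logical row index in the gathered view is split into (block, offset)
// and redirected to the physical block recorded in the block table.
int gather_row(const std::vector<int>& block_ids, int block_size, int idx) {
  const int blk_idx = idx / block_size;     // logical block
  const int blk_offset = idx % block_size;  // offset within the block
  assert(blk_idx < static_cast<int>(block_ids.size()));
  return block_ids[blk_idx] * block_size + blk_offset;  // physical row
}

int main() {
  // Same permutation the test builds: logical block i lives in physical
  // block (i ^ 1), so neighboring blocks are swapped pairwise.
  const int block_size = 8;
  std::vector<int> block_ids;
  for (int i = 0; i < 4; ++i) {
    block_ids.push_back(i ^ 1);
  }
  // Logical rows 0..7 map to physical rows 8..15 and vice versa.
  for (int idx : {0, 7, 8, 15, 16, 31}) {
    std::printf("logical %2d -> physical %2d\n",
                idx,
                gather_row(block_ids, block_size, idx));
  }
  return 0;
}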