diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt index c2a9d491..df7fd7e9 100644 --- a/src/kernels/attention/CMakeLists.txt +++ b/src/kernels/attention/CMakeLists.txt @@ -124,7 +124,7 @@ cc_test( NAME tma_test SRCS - sm120_tma_pagedkv_test.cu + sm120_tma_block_load_test.cu DEPS :gtest_main :cutlass diff --git a/src/kernels/attention/gather_tma_copy.h b/src/kernels/attention/gather_tma_copy.h new file mode 100644 index 00000000..a8be0586 --- /dev/null +++ b/src/kernels/attention/gather_tma_copy.h @@ -0,0 +1,106 @@ +#pragma once +#include + +namespace cute { + +namespace detail { + +template +CUTE_HOST_RTC auto get_tma_atom_slayout( + SLayout const& slayout // ((ATOM_M, m), (ATOM_N, n)) +) { + return composition(slayout.layout_a(), + slayout.offset(), + make_layout(get<0, 0>(slayout.layout_b()), + get<1, 0>(slayout.layout_b()))); +} + +template +CUTE_HOST_RTC auto make_gather_tma_copy_atom( + CopyOp copy_op, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + uint32_t const& num_multicast) // The num of CTAs involved in multicasting) +{ + auto atom_slayout = get_tma_atom_slayout(slayout); + auto atom_tiler = product_each(shape(atom_slayout)); + auto atom_v_map = make_identity_layout(shape(gtensor)).compose(atom_tiler); + return make_tma_copy_atom( + copy_op, gtensor, atom_slayout, num_multicast, atom_v_map); +} + +template +CUTE_HOST_RTC auto make_gather_tma_copy_tiled( + CopyOp const& copy_op, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM + Layout const& cta_t_map, // T: Thr idx -> logical TMA tid + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + // Construct tma copy atom + auto atom = make_gather_tma_copy_atom( + copy_op, gtensor, slayout, cosize(cta_t_map)); + + // Construct the TiledCopy + [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map)); + + auto num_elems_per_tma = + size<1>(typename decltype(atom)::RefLayout{}) / + static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), + safe_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = + composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); + return TiledCopy{ + atom}; +} + +} // namespace detail + +template +CUTE_HOST_RTC auto make_gather_tma_copy( + CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, // ((ATOM_M, m), (ATOM_N, n)) + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) { + // Thr idx -> logical TMA tid + auto cta_t_map = make_layout(cluster_size); + // CTA val idx -> gmem mode + auto cta_v_map = make_identity_layout(shape(gtensor)).compose(cta_tiler); + using TmaType = conditional_t, + typename GEngine::value_type, + TmaInternalType>; + return detail::make_gather_tma_copy_tiled( + copy_op, gtensor, slayout, cta_t_map, cta_v_map); +} +} // namespace cute diff --git a/src/kernels/attention/sm120_tma_pagedkv_test.cu b/src/kernels/attention/sm120_tma_block_load_test.cu similarity index 95% rename from src/kernels/attention/sm120_tma_pagedkv_test.cu rename to src/kernels/attention/sm120_tma_block_load_test.cu index a5c51486..78a6117f 100644 --- a/src/kernels/attention/sm120_tma_pagedkv_test.cu +++ b/src/kernels/attention/sm120_tma_block_load_test.cu @@ -6,6 +6,7 @@ #include #include "cute/swizzle_layout.hpp" +#include "gather_tma_copy.h" #include "gather_tma_tensor.hpp" namespace llm { @@ -110,11 +111,6 @@ __global__ void tma_test_device_cute(T* g_out, sizeof(make_tensor_like(tensor<0>(tAsA))); if (threadIdx.x == 0) { - // print("\n ########### %d ########### \n", stage); - // print("sA: "); - // print(sA); - // print("\n"); - /// Initialize shared memory barrier tma_load_mbar[0] = 0; cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); @@ -191,8 +187,8 @@ auto test_tma_block_load(CopyOp const& copy_op, // Create TMA for this device Tensor Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); - auto tma = - make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + auto tma = make_gather_tma_copy( + copy_op, gA, smem_layout, cta_tile, Int<1>{}); // Launch int smem_size = int(sizeof(SharedStorage)); @@ -266,15 +262,12 @@ auto test_tma_block_load(GMEM_Layout const& gmem_layout, template typename SWIZZLE_ATOM> auto test_tma_block_load_swizzle_tile_k(int32_t block_size) { - auto gmem_layout = make_layout(make_shape(256, 256), GenRowMajor{}); + auto gmem_layout = make_layout(make_shape(256, 128), GenRowMajor{}); - auto gather_shape = Shape<_64, _256>{}; + auto gather_shape = Shape<_128, _128>{}; auto gather_gmem_layout = make_layout(gather_shape, GenRowMajor{}); auto smem_layout = tile_to_shape(SWIZZLE_ATOM{}, gather_shape, Step<_1, _0>{}); - - // TODO: fix the test failures related to tma box size - // assert (size<1>(SWIZZLE_ATOM{}) != size<1>(smem_layout)); return test_tma_block_load( gmem_layout, gather_gmem_layout, smem_layout, block_size); } @@ -299,6 +292,9 @@ auto test_tma_block_load(int32_t block_size) { block_size); } -TEST(SM120_Tma, Test_Tma_Block_Load) { test_tma_block_load(/*block_size=*/8); } +TEST(SM120_Tma, Test_Tma_Block_Load) { + test_tma_block_load(/*block_size=*/8); + test_tma_block_load(/*block_size=*/16); +} } // namespace llm