Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/kernels/attention/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ cc_test(
NAME
tma_test
SRCS
sm120_tma_pagedkv_test.cu
sm120_tma_block_load_test.cu
DEPS
:gtest_main
:cutlass
Expand Down
106 changes: 106 additions & 0 deletions src/kernels/attention/gather_tma_copy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#pragma once
#include <cute/atom/copy_traits_sm90_tma.hpp>

namespace cute {

namespace detail {

// Extracts the single-atom SMEM sublayout from a full CTA tile layout.
// `slayout` is a composed (potentially swizzled) layout shaped
// ((ATOM_M, m), (ATOM_N, n)): the swizzle (layout_a) and offset are kept,
// while the inner layout (layout_b) is shrunk to only the leading mode of
// each hierarchical dimension, i.e. the extent of one (ATOM_M, ATOM_N) box.
template <class SLayout>
CUTE_HOST_RTC auto get_tma_atom_slayout(
    SLayout const& slayout  // ((ATOM_M, m), (ATOM_N, n))
) {
  // Recompose: swizzle ∘ offset ∘ (first mode of each dim of the inner layout).
  return composition(slayout.layout_a(),
                     slayout.offset(),
                     make_layout(get<0, 0>(slayout.layout_b()),
                                 get<1, 0>(slayout.layout_b())));
}

// Builds a TMA Copy_Atom for `copy_op` that transfers one (ATOM_M, ATOM_N)
// sub-tile of `gtensor` into SMEM, using the atom-sized slice of `slayout`.
// The atom is later tiled across the full CTA tile by
// make_gather_tma_copy_tiled.
template <class TmaInternalType,
          class CopyOp,
          class GEngine,
          class GLayout,
          class SLayout>
CUTE_HOST_RTC auto make_gather_tma_copy_atom(
    CopyOp copy_op,
    Tensor<GEngine, GLayout> const& gtensor,  // Full GMEM Tensor
    SLayout const& slayout,  // CTA Tile of SMEM, potentially swizzled
    uint32_t const& num_multicast)  // The num of CTAs involved in multicasting
{
  // Single-atom SMEM layout: the (ATOM_M, ATOM_N) slice of the CTA tile.
  auto atom_slayout = get_tma_atom_slayout(slayout);
  // Tiler covering one atom's extent in each mode.
  auto atom_tiler = product_each(shape(atom_slayout));
  // V: atom val idx -> gmem mode (identity over the atom tile).
  auto atom_v_map = make_identity_layout(shape(gtensor)).compose(atom_tiler);
  return make_tma_copy_atom<TmaInternalType>(
      copy_op, gtensor, atom_slayout, num_multicast, atom_v_map);
}

// Assembles a TiledCopy from a gather TMA atom: the atom covers one
// (ATOM_M, ATOM_N) box, and the T/V layouts constructed below replicate it
// across the whole CTA tile and the participating TMA tids.
template <class TmaType,
          class CopyOp,
          class GEngine,
          class GLayout,
          class SLayout,
          class TShape,
          class TStride,
          class VShape,
          class VStride>
CUTE_HOST_RTC auto make_gather_tma_copy_tiled(
    CopyOp const& copy_op,
    Tensor<GEngine, GLayout> const& gtensor,   // Full GMEM Tensor
    SLayout const& slayout,                    // CTA Tile of SMEM
    Layout<TShape, TStride> const& cta_t_map,  // T: Thr idx -> logical TMA tid
    Layout<VShape, VStride> const& cta_v_map)  // V: CTA val idx -> gmem mode
{
  // Construct tma copy atom; multicast degree = number of logical TMA tids.
  auto atom = make_gather_tma_copy_atom<TmaType>(
      copy_op, gtensor, slayout, cosize(cta_t_map));

  // Construct the TiledCopy
  [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map));

  // Number of logical elements one TMA instruction transfers:
  // RefLayout's value mode is sized in bits, so divide by the element width.
  auto num_elems_per_tma =
      size<1>(typename decltype(atom)::RefLayout{}) /
      static_value<sizeof_bits<typename GEngine::value_type>>();

  // smem idx -> smem coord
  auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
  // CTA V -> smem_coord (restricted to one TMA instruction's elements)
  auto layout_v = composition(inv_smem_layout, num_elems_per_tma);
  // Scale that up to cover all of the smem_coords
  auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map));
  // CTA T -> smem idx (each tid owns an equal contiguous share of the atom)
  auto layout_t = make_layout(cosize(cta_t_map),
                              safe_div(num_elems_per_tma, cosize(cta_t_map)));
  // CTA TID -> smem coord
  auto layout_T =
      composition(inv_smem_layout, composition(layout_t, cta_t_map));
  // Combine with the T mapping; only its type is needed for the TiledCopy.
  [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V);
  return TiledCopy<decltype(atom), decltype(layout_TV), decltype(cta_tiler)>{
      atom};
}

} // namespace detail

// Public entry point: builds a gather TiledCopy that moves `cta_tiler`-shaped
// tiles of `gtensor` into SMEM laid out by `slayout`, optionally multicast
// across `cluster_size` CTAs.
template <class TmaInternalType = void,
          class CopyOp,
          class GEngine,
          class GLayout,
          class SLayout,
          class CTA_Tiler,
          class Cluster_Size>
CUTE_HOST_RTC auto make_gather_tma_copy(
    CopyOp const& copy_op,
    Tensor<GEngine, GLayout> const& gtensor,
    SLayout const& slayout,  // ((ATOM_M, m), (ATOM_N, n))
    CTA_Tiler const& cta_tiler,
    Cluster_Size const& cluster_size) {
  // When no internal TMA type is requested, default to the tensor's
  // element type.
  using TmaType = conditional_t<is_same_v<void, TmaInternalType>,
                                typename GEngine::value_type,
                                TmaInternalType>;
  // Thr idx -> logical TMA tid, one per CTA participating in the cluster.
  auto thr_layout = make_layout(cluster_size);
  // CTA val idx -> gmem mode, i.e. the identity restricted to the CTA tile.
  auto val_layout = make_identity_layout(shape(gtensor)).compose(cta_tiler);
  return detail::make_gather_tma_copy_tiled<TmaType>(
      copy_op, gtensor, slayout, thr_layout, val_layout);
}
} // namespace cute
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cute/tensor.hpp>

#include "cute/swizzle_layout.hpp"
#include "gather_tma_copy.h"
#include "gather_tma_tensor.hpp"

namespace llm {
Expand Down Expand Up @@ -110,11 +111,6 @@ __global__ void tma_test_device_cute(T* g_out,
sizeof(make_tensor_like(tensor<0>(tAsA)));

if (threadIdx.x == 0) {
// print("\n ########### %d ########### \n", stage);
// print("sA: ");
// print(sA);
// print("\n");

/// Initialize shared memory barrier
tma_load_mbar[0] = 0;
cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/);
Expand Down Expand Up @@ -191,8 +187,8 @@ auto test_tma_block_load(CopyOp const& copy_op,
// Create TMA for this device Tensor
Tensor gA =
make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_in.data())), gmem_layout);
auto tma =
make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{});
auto tma = make_gather_tma_copy<TmaType>(
copy_op, gA, smem_layout, cta_tile, Int<1>{});

// Launch
int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
Expand Down Expand Up @@ -266,15 +262,12 @@ auto test_tma_block_load(GMEM_Layout const& gmem_layout,

template <class T, template <typename> typename SWIZZLE_ATOM>
auto test_tma_block_load_swizzle_tile_k(int32_t block_size) {
auto gmem_layout = make_layout(make_shape(256, 256), GenRowMajor{});
auto gmem_layout = make_layout(make_shape(256, 128), GenRowMajor{});

auto gather_shape = Shape<_64, _256>{};
auto gather_shape = Shape<_128, _128>{};
auto gather_gmem_layout = make_layout(gather_shape, GenRowMajor{});
auto smem_layout =
tile_to_shape(SWIZZLE_ATOM<T>{}, gather_shape, Step<_1, _0>{});

// TODO: fix the test failures related to tma box size
// assert (size<1>(SWIZZLE_ATOM<T>{}) != size<1>(smem_layout));
return test_tma_block_load<T>(
gmem_layout, gather_gmem_layout, smem_layout, block_size);
}
Expand All @@ -299,6 +292,9 @@ auto test_tma_block_load(int32_t block_size) {
block_size);
}

TEST(SM120_Tma, Test_Tma_Block_Load) { test_tma_block_load(/*block_size=*/8); }
TEST(SM120_Tma, Test_Tma_Block_Load) {
test_tma_block_load(/*block_size=*/8);
test_tma_block_load(/*block_size=*/16);
}

} // namespace llm