@@ -22,22 +22,62 @@ CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a,
2222 return elem_less (get<I>(a), get<I>(b));
2323}
2424
25- template <bool EVEN_K,
26- bool EVEN_MN,
27- bool ZERO_FILL_MN,
28- bool ZERO_FILL_K,
29- class TiledCopy ,
25+ template <class Copy_Atom , class TensorS , class TensorD >
26+ CUTE_HOST_DEVICE void zfill (const Copy_Atom& copy_atom,
27+ const TensorS& src,
28+ TensorD&& dst) {
29+ CUTE_STATIC_ASSERT (TensorS::rank == TensorD::rank, " rank-mismatch." );
30+
31+ auto has_with_bool = cute::is_valid (
32+ [](auto t) -> void_t <decltype (declval<typename decltype (t)::Traits>()
33+ .with (true ))> {},
34+ copy_atom);
35+ if constexpr (has_with_bool) {
36+ constexpr int R = TensorD::rank;
37+ if constexpr (R == 1 ) { // Dispatch the copy
38+ copy_atom.with (false ).call (src, dst);
39+ } else { // Loop over all but the first mode
40+ Tensor src_v = group_modes<1 , R>(src);
41+ Tensor dst_v = group_modes<1 , R>(dst);
42+ CUTE_UNROLL
43+ for (int i = 0 ; i < size<1 >(dst_v); ++i) {
44+ copy_atom.with (false ).call (src_v (_, i), dst_v (_, i));
45+ }
46+ }
47+ } else {
48+ // just call clear if no with method
49+ clear (dst);
50+ }
51+ }
52+
53+ template <class Copy_Atom , class TensorS , class TensorD >
54+ CUTE_HOST_DEVICE void zfill (const Copy_Atom& copy_atom,
55+ const TensorS& src,
56+ TensorD& dst) {
57+ zfill (copy_atom, src, dst);
58+ }
59+
60+ template <bool EVEN_MN,
61+ bool EVEN_K,
62+ bool ZFILL_MN = true ,
63+ bool ZFILL_K = true ,
64+ class CopyAtom ,
65+ class TV ,
66+ class Tiler ,
3067 class TensorS ,
3168 class TensorD ,
3269 class TensorC ,
3370 class Coord >
3471CUTE_HOST_DEVICE void safe_copy (
35- const TiledCopy& tiled_copy,
72+ const TiledCopy<CopyAtom, TV, Tiler> & tiled_copy,
3673 const TensorS& src, // (CPY, CPY_M/N, CPY_K)
3774 TensorD& dst, // (CPY, CPY_M/N, CPY_K)
3875 const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (blk_m/n, blk_k)
3976 const Coord& max_coord // max_coord(blk_m/n, blk_k)
4077) {
78+ CUTE_STATIC_ASSERT (TensorS::rank == TensorD::rank, " rank-mismatch." );
79+ auto copy_atom = static_cast <const CopyAtom&>(tiled_copy);
80+
4181 if constexpr (!EVEN_MN && !EVEN_K) {
4282 // handle both m/n and k oob
4383 CUTE_UNROLL
@@ -46,16 +86,16 @@ CUTE_HOST_DEVICE void safe_copy(
4686 CUTE_UNROLL
4787 for (int ki = 0 ; ki < size<2 >(src); ++ki) {
4888 if (elem_less<1 >(identity (_0{}, _0{}, ki), max_coord)) {
49- copy (tiled_copy , src (_, mi, ki), dst (_, mi, ki));
89+ copy (copy_atom , src (_, mi, ki), dst (_, mi, ki));
5090 } else {
51- if constexpr (ZERO_FILL_K ) {
52- clear ( dst (_, mi, ki));
91+ if constexpr (ZFILL_K ) {
92+ zfill (copy_atom, src (_, mi, ki), dst (_, mi, ki));
5393 }
5494 }
5595 }
5696 } else {
57- if constexpr (ZERO_FILL_MN ) {
58- clear ( dst (_, mi, _));
97+ if constexpr (ZFILL_MN ) {
98+ zfill (copy_atom, src (_, mi, _), dst (_, mi, _));
5999 }
60100 }
61101 }
@@ -64,10 +104,10 @@ CUTE_HOST_DEVICE void safe_copy(
64104 CUTE_UNROLL
65105 for (int mi = 0 ; mi < size<1 >(src); ++mi) {
66106 if (elem_less<0 >(identity (_0{}, mi, _0{}), max_coord)) {
67- copy (tiled_copy , src (_, mi, _), dst (_, mi, _));
107+ copy (copy_atom , src (_, mi, _), dst (_, mi, _));
68108 } else {
69- if constexpr (ZERO_FILL_MN ) {
70- clear ( dst (_, mi, _));
109+ if constexpr (ZFILL_MN ) {
110+ zfill (copy_atom, src (_, mi, _), dst (_, mi, _));
71111 }
72112 }
73113 }
@@ -76,16 +116,16 @@ CUTE_HOST_DEVICE void safe_copy(
76116 CUTE_UNROLL
77117 for (int ki = 0 ; ki < size<2 >(src); ++ki) {
78118 if (elem_less<1 >(identity (_0{}, _0{}, ki), max_coord)) {
79- copy (tiled_copy , src (_, _, ki), dst (_, _, ki));
119+ copy (copy_atom , src (_, _, ki), dst (_, _, ki));
80120 } else {
81- if constexpr (ZERO_FILL_K ) {
82- clear ( dst (_, _, ki));
121+ if constexpr (ZFILL_K ) {
122+ zfill (copy_atom, src (_, _, ki), dst (_, _, ki));
83123 }
84124 }
85125 }
86126 } else {
87127 // no oob, just copy
88- copy (tiled_copy , src, dst);
128+ copy (copy_atom , src, dst);
89129 }
90130}
91131
0 commit comments