@@ -5910,7 +5910,9 @@ struct WmmaFragment
59105910 if constexpr (sizeof (T) == 4 )
59115911 {
59125912 // T is 32-bit (float or int32): 1 element per register
5913- return *reinterpret_cast <const T*>(®s[index]);
5913+ T v;
5914+ memcpy (&v, ®s[index], 4 );
5915+ return v;
59145916 }
59155917 else if constexpr (sizeof (T) == 2 )
59165918 {
@@ -5921,7 +5923,9 @@ struct WmmaFragment
59215923 int bitOffset = elementOffset * 16 ;
59225924 uint32_t extracted = (regs[regIndex] >> bitOffset) & 0xFFFF ;
59235925 uint16_t value16 = static_cast <uint16_t >(extracted);
5924- return *reinterpret_cast <const T*>(&value16);
5926+ T v;
5927+ memcpy (&v, &value16, 2 );
5928+ return v;
59255929 }
59265930 else if constexpr (sizeof (T) == 1 )
59275931 {
@@ -5932,7 +5936,9 @@ struct WmmaFragment
59325936 int bitOffset = elementOffset * 8 ;
59335937 uint32_t extracted = (regs[regIndex] >> bitOffset) & 0xFF ;
59345938 uint8_t value8 = static_cast <uint8_t >(extracted);
5935- return *reinterpret_cast <const T*>(&value8);
5939+ T v;
5940+ memcpy (&v, &value8, 1 );
5941+ return v;
59365942 }
59375943 }
59385944
@@ -5942,7 +5948,7 @@ struct WmmaFragment
59425948 if constexpr (sizeof (T) == 4 )
59435949 {
59445950 // T is 32-bit (float or int32): 1 element per register
5945- regs[index] = * reinterpret_cast < const uint32_t *>( &value);
5951+ memcpy (& regs[index], &value, 4 );
59465952 }
59475953 else if constexpr (sizeof (T) == 2 )
59485954 {
@@ -5951,7 +5957,8 @@ struct WmmaFragment
59515957 int elementOffset = index % 2 ;
59525958 int bitOffset = elementOffset * 16 ;
59535959 uint32_t mask = 0xFFFF ;
5954- uint16_t value16 = *reinterpret_cast <const uint16_t *>(&value);
5960+ uint16_t value16;
5961+ memcpy (value16, &value, 2 );
59555962
59565963 // Clear the bits at the target position
59575964 regs[regIndex] &= ~(mask << bitOffset);
@@ -5966,7 +5973,8 @@ struct WmmaFragment
59665973 int elementOffset = index % 4 ;
59675974 int bitOffset = elementOffset * 8 ;
59685975 uint32_t mask = 0xFF ;
5969- uint8_t value8 = *reinterpret_cast <const uint8_t *>(&value);
5976+ uint8_t value8;
5977+ memcpy (value8, &value, 1 );
59705978
59715979 // Clear the bits at the target position
59725980 regs[regIndex] &= ~(mask << bitOffset);
@@ -6007,7 +6015,7 @@ struct WmmaFragment
60076015
60086016 // Maximum registers needed across all fragment types and data types
60096017 static constexpr int MAX_REGS = 8 ;
6010- unsigned regs[MAX_REGS] = {};
6018+ uint32_t regs[MAX_REGS] = {};
60116019
60126020 static constexpr uint32_t elements_per_warp = (R == MatrixUse::MatrixA) ? (M * K)
60136021 : (R == MatrixUse::MatrixB) ? (K * N)
0 commit comments