Skip to content

Commit 3e46ab2

Browse files
committed
Fix incorrect use of reinterpret_cast
The reinterpret_cast cannot be used to cast pointer types that are not pointer-interconvertible. https://eel.is/c++draft/basic.compound#5 https://eel.is/c++draft/conv.qual#2
1 parent 671b391 commit 3e46ab2

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

prelude/slang-cuda-prelude.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5910,7 +5910,9 @@ struct WmmaFragment
59105910
if constexpr (sizeof(T) == 4)
59115911
{
59125912
// T is 32-bit (float or int32): 1 element per register
5913-
return *reinterpret_cast<const T*>(&regs[index]);
5913+
T v;
5914+
memcpy(&v, &regs[index], 4);
5915+
return v;
59145916
}
59155917
else if constexpr (sizeof(T) == 2)
59165918
{
@@ -5921,7 +5923,9 @@ struct WmmaFragment
59215923
int bitOffset = elementOffset * 16;
59225924
uint32_t extracted = (regs[regIndex] >> bitOffset) & 0xFFFF;
59235925
uint16_t value16 = static_cast<uint16_t>(extracted);
5924-
return *reinterpret_cast<const T*>(&value16);
5926+
T v;
5927+
memcpy(&v, &value16, 2);
5928+
return v;
59255929
}
59265930
else if constexpr (sizeof(T) == 1)
59275931
{
@@ -5932,7 +5936,9 @@ struct WmmaFragment
59325936
int bitOffset = elementOffset * 8;
59335937
uint32_t extracted = (regs[regIndex] >> bitOffset) & 0xFF;
59345938
uint8_t value8 = static_cast<uint8_t>(extracted);
5935-
return *reinterpret_cast<const T*>(&value8);
5939+
T v;
5940+
memcpy(&v, &value8, 1);
5941+
return v;
59365942
}
59375943
}
59385944

@@ -5942,7 +5948,7 @@ struct WmmaFragment
59425948
if constexpr (sizeof(T) == 4)
59435949
{
59445950
// T is 32-bit (float or int32): 1 element per register
5945-
regs[index] = *reinterpret_cast<const uint32_t*>(&value);
5951+
memcpy(&regs[index], &value, 4);
59465952
}
59475953
else if constexpr (sizeof(T) == 2)
59485954
{
@@ -5951,7 +5957,8 @@ struct WmmaFragment
59515957
int elementOffset = index % 2;
59525958
int bitOffset = elementOffset * 16;
59535959
uint32_t mask = 0xFFFF;
5954-
uint16_t value16 = *reinterpret_cast<const uint16_t*>(&value);
5960+
uint16_t value16;
5961+
memcpy(value16, &value, 2);
59555962

59565963
// Clear the bits at the target position
59575964
regs[regIndex] &= ~(mask << bitOffset);
@@ -5966,7 +5973,8 @@ struct WmmaFragment
59665973
int elementOffset = index % 4;
59675974
int bitOffset = elementOffset * 8;
59685975
uint32_t mask = 0xFF;
5969-
uint8_t value8 = *reinterpret_cast<const uint8_t*>(&value);
5976+
uint8_t value8;
5977+
memcpy(value8, &value, 1);
59705978

59715979
// Clear the bits at the target position
59725980
regs[regIndex] &= ~(mask << bitOffset);
@@ -6007,7 +6015,7 @@ struct WmmaFragment
60076015

60086016
// Maximum registers needed across all fragment types and data types
60096017
static constexpr int MAX_REGS = 8;
6010-
unsigned regs[MAX_REGS] = {};
6018+
uint32_t regs[MAX_REGS] = {};
60116019

60126020
static constexpr uint32_t elements_per_warp = (R == MatrixUse::MatrixA) ? (M * K)
60136021
: (R == MatrixUse::MatrixB) ? (K * N)

0 commit comments

Comments
 (0)