Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
44f81cb
implement coopmat for neural
kaizhangNV Nov 21, 2025
99aa203
implement test for tile shared mem load
kaizhangNV Nov 24, 2025
ab6a309
Finish aligned test for shA and shB load
kaizhangNV Nov 24, 2025
ee03a40
finish unaligned test for shA and shB load
kaizhangNV Nov 26, 2025
3c12072
refactor vectorized reader
kaizhangNV Nov 26, 2025
a2e41ef
Add more tests for unAligned size load
kaizhangNV Nov 28, 2025
8c42bf2
refine tests and fix issue in matB load
kaizhangNV Nov 28, 2025
b76d51c
start testing mma
kaizhangNV Dec 1, 2025
ab2657f
drop the layout template parameter from wmma intrinsic
kaizhangNV Dec 1, 2025
ee781c2
testing for mma
kaizhangNV Dec 3, 2025
430e1ee
start testing arbitrary sized inputs
kaizhangNV Dec 5, 2025
cfad1d9
finish single/multi warps test for mma
kaizhangNV Dec 5, 2025
afab86a
Fix test bug
kaizhangNV Dec 8, 2025
83d05cb
WIP: AccelerateVectorCoopMat impl
kaizhangNV Dec 8, 2025
89f8877
implement loading transposed matrix A
kaizhangNV Dec 10, 2025
b716c93
name refactor
kaizhangNV Dec 10, 2025
1ace626
generalize mma interface and refactor name again
kaizhangNV Dec 10, 2025
3367e64
add transpose mma test
kaizhangNV Dec 10, 2025
4df3cdc
add more transpose tests and fix bugs
kaizhangNV Dec 11, 2025
341a58b
add group sync for all multi-warps tests
kaizhangNV Dec 11, 2025
54de96e
implement backward partially and fix a stupid error in inline vector
kaizhangNV Dec 11, 2025
8a5e3bd
impl outproductAccumulate
kaizhangNV Dec 16, 2025
ac24b6e
debug
kaizhangNV Dec 17, 2025
64365aa
Re-write the logic to get coopMat tile size
kaizhangNV Dec 18, 2025
beec35f
fix issue in the coopmat intrinsic
kaizhangNV Dec 22, 2025
a14a5eb
finish outerproduct implementation but need writeAtomic to storage b…
kaizhangNV Dec 22, 2025
47b2ba6
define an IArrayAccessor to unify read/write operation
kaizhangNV Dec 22, 2025
d780f89
refactor readUint4Aligned
kaizhangNV Dec 22, 2025
1aea684
refactor writeUint4Aligned
kaizhangNV Dec 22, 2025
e526eb2
combine readUint4Aligned and writeUint4Aligned
kaizhangNV Dec 22, 2025
5e1e8d4
unify readuint4/writeuint4/atomicAddUint4
kaizhangNV Dec 22, 2025
b8b5626
fix all the known issues for outerproductAccumulate
kaizhangNV Dec 23, 2025
2444160
use float as storage type for Vulkan test
kaizhangNV Dec 24, 2025
408792d
fix an issue when writing back result to global memory
kaizhangNV Dec 24, 2025
bed965c
improve the vectorized read/write
kaizhangNV Dec 24, 2025
548b964
update two tests to cover both float/half type input buffer
kaizhangNV Dec 24, 2025
ec377e1
test forward and backward pass for coopmat backend
kaizhangNV Dec 24, 2025
8fa458d
refactor load/store vector
kaizhangNV Dec 25, 2025
492b0f5
update MMAHelper to support both half and float for MatC
kaizhangNV Dec 25, 2025
b8898c6
update MMAHelper to support both half and float for MatC for outerpro…
kaizhangNV Dec 26, 2025
78a950a
fix an alignment issue
kaizhangNV Dec 29, 2025
3aa0f19
add utility to calc shared memory size
kaizhangNV Dec 29, 2025
12f1d10
add shared memory size calculator
kaizhangNV Dec 30, 2025
933d98d
add test for shared memory size calculator
kaizhangNV Jan 1, 2026
1044322
update basic_coopmat test to use SharedMemoryPool
kaizhangNV Jan 1, 2026
6bb3e9a
refactor to remove the N from IVector type
kaizhangNV Jan 1, 2026
9106a9f
add support for bias
kaizhangNV Jan 2, 2026
8c96e81
Fix a specialization bug and shMemPool refactor
kaizhangNV Jan 5, 2026
f56586b
add test for mma with bias
kaizhangNV Jan 6, 2026
91f877f
fix a test failure
kaizhangNV Jan 6, 2026
6e4d2c1
refactor test
kaizhangNV Jan 6, 2026
176e08b
fix test issue
kaizhangNV Jan 7, 2026
31dd00c
keep refactoring tests
kaizhangNV Jan 7, 2026
86747e7
finish basic coopmat test
kaizhangNV Jan 7, 2026
b50b025
make the mat store coherent
kaizhangNV Jan 7, 2026
70044f4
make the subgroup scope sync uniform cross whole workgroup
kaizhangNV Jan 9, 2026
3d56a78
remove the warp sync after wmma.store
kaizhangNV Jan 12, 2026
a6cdb49
WIP: run CI on the upgraded driver runner
kaizhangNV Jan 12, 2026
94e264e
remove the WAR because of the driver bug
kaizhangNV Jan 12, 2026
c1630a3
adjust the value scale in a failed test
kaizhangNV Jan 12, 2026
02606a3
WIP:run CI on the upgraded driver runner
kaizhangNV Jan 12, 2026
62ce1af
DNI: try to increase the timeout time
kaizhangNV Jan 12, 2026
305622f
add develop build mechanism
kaizhangNV Jan 13, 2026
d98ef9b
DNI: CI doesn't run
kaizhangNV Jan 13, 2026
e680f15
inline everything
kaizhangNV Jan 13, 2026
a9212fd
Revert "DNI: CI doesn't run"
kaizhangNV Jan 13, 2026
3b30e80
Revert "DNI: try to increase the timeout time"
kaizhangNV Jan 13, 2026
9badc94
fix the last bug in the tests
kaizhangNV Jan 14, 2026
a05be84
run tests on -emit-spirv-directly mode only
kaizhangNV Jan 14, 2026
accf5e0
reserve a layout generic parameter for future optimization
kaizhangNV Jan 15, 2026
2c5eae5
revert CI change to run the job on any runners
kaizhangNV Jan 26, 2026
629fe0f
change ci-slang-test timeout to 60 minutes
kaizhangNV Jan 26, 2026
af85197
change slangpy-test timeout to 60 minutes
kaizhangNV Jan 26, 2026
544aaf2
update the name
kaizhangNV Jan 27, 2026
2beb8cc
increase the timeout for slang-test
kaizhangNV Jan 27, 2026
f9df8c3
implement a skip-list in the slang-test
kaizhangNV Jan 28, 2026
7b6b216
Merge branch 'master' into coopMat-impl
expipiplus1 Jan 28, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/ci-slang-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ on:
jobs:
test-slang:
runs-on: ${{ fromJSON(inputs.runs-on) }}
timeout-minutes: 30
timeout-minutes: 60
defaults:
run:
shell: bash
Expand Down Expand Up @@ -93,6 +93,11 @@ jobs:
slang_test_args+=("-expected-failure-list" "tests/expected-failure-linux-gpu.txt")
fi

# Skip most neural tests in debug builds to reduce CI time
if [[ "${{ inputs.config }}" == "debug" ]]; then
slang_test_args+=("-skip-list" "tests/skip-list-debug.txt")
fi

# Execute slang-test with all arguments
"$bin_dir/slang-test" "${slang_test_args[@]}"
- name: Run Slang examples
Expand Down Expand Up @@ -174,7 +179,7 @@ jobs:
# Includes both slangpy pytest tests and slangpy-samples examples.
test-slangpy:
runs-on: ${{ fromJSON(inputs.runs-on) }}
timeout-minutes: 30
timeout-minutes: 60
if: inputs.full-gpu-tests && (github.event_name == 'pull_request' || inputs.config == 'release')
defaults:
run:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ jobs:
-DSLANG_GENERATORS_PATH=build-platform-generators/bin \
-DSLANG_ENABLE_EXAMPLES=OFF \
-DSLANG_ENABLE_RELEASE_LTO=ON \
-DSLANG_STANDARD_MODULE_DEVELOP_BUILD=OFF \
"-DSLANG_SLANG_LLVM_FLAVOR=$(
[[ "${{matrix.build-slang-llvm}}" = "true" ]] && echo "USE_SYSTEM_LLVM" || echo "DISABLE")" \
${{matrix.extra-cmake-flags}}
Expand Down
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ option(SLANG_ENABLE_IR_BREAK_ALLOC "Enable _debugUID on IR allocation")
option(SLANG_ENABLE_ASAN "Enable ASAN (address sanitizer)")
option(SLANG_ENABLE_COVERAGE "Enable code coverage instrumentation")

option(
SLANG_STANDARD_MODULE_DEVELOP_BUILD
"Enable development build for standard modules (enables UNIT_TEST macro). Disable for release builds."
ON
)

option(SLANG_ENABLE_PREBUILT_BINARIES "Enable using prebuilt binaries" ON)
option(SLANG_ENABLE_GFX "Enable gfx targets" ON)
option(SLANG_ENABLE_SLANGD "Enable language server target" ON)
Expand Down
152 changes: 99 additions & 53 deletions prelude/slang-cuda-prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -6462,7 +6462,7 @@ struct IsSaturated<false>
// ====================================================================================

template<typename ElemT, int M, int N, int K, MatrixUse use, Layout layout>
__device__ inline void wmmaLoad(uint32_t* regs, const ElemT* ptr, int stride)
__device__ inline void wmmaLoad(uint32_t* regs, const void* ptr, int stride)
{
constexpr int nregs = RegisterCount<ElemT, M, N, K, use>::value;

Expand Down Expand Up @@ -6527,7 +6527,7 @@ __device__ inline void wmmaLoad(uint32_t* regs, const ElemT* ptr, int stride)
// ====================================================================================

template<typename ElemT, int M, int N, int K, Layout layout>
__device__ inline void wmmaStore(ElemT* ptr, const uint32_t* regs, int stride)
__device__ inline void wmmaStore(void* ptr, const uint32_t* regs, int stride)
{
constexpr int nregs = RegisterCount<ElemT, M, N, K, MatrixUse::MatrixD>::value;

Expand Down Expand Up @@ -6623,7 +6623,7 @@ inline unsigned __device__ Pack32Helper<unsigned char>(unsigned char value)

// The dimensions of the fragment are specified by M, N, K which are totally determined during
// compile time, so slang already did the pre-filter on the shape & type combination.
template<typename T, int M, int N, int K, MatrixUse R, Layout MatrixLayout = RowMajor>
template<typename T, int M, int N, int K, MatrixUse R>
struct WmmaFragment
{
__device__ WmmaFragment() {}
Expand All @@ -6647,25 +6647,19 @@ struct WmmaFragment
void __device__ fill(T value)
{
unsigned packed = Pack32Helper(value);

// Manually assign to prevent register coalescing
regs[0] = packed;
regs[1] = packed;
regs[2] = packed;
regs[3] = packed;
regs[4] = packed;
regs[5] = packed;
regs[6] = packed;
regs[7] = packed;
constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
for (int i = 0; i < nregs; i++)
{
regs[i] = packed;
}
}

__device__ This operator*(T b)
{
constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
This result;

// This loop will be unrolled by the compiler because nregs is constexpr
for (int i = 0; i < nregs; i++)
for (int i = 0; i < GetLength(); i++)
{
result.set(i, get(i) * b);
}
Expand All @@ -6674,11 +6668,10 @@ struct WmmaFragment

__device__ This operator*(const This& b)
{
constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
This result;

// This loop will be unrolled by the compiler because nregs is constexpr
for (int i = 0; i < nregs; i++)
for (int i = 0; i < GetLength(); i++)
{
result.set(i, get(i) * b.get(i));
}
Expand All @@ -6687,10 +6680,9 @@ struct WmmaFragment

__device__ This operator/(const This& other)
{
constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
This result;

for (int i = 0; i < nregs; i++)
for (int i = 0; i < GetLength(); i++)
{
result.set(i, get(i) / other.get(i));
}
Expand All @@ -6699,10 +6691,9 @@ struct WmmaFragment

__device__ This operator-(const This& other)
{
constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
This result;

for (int i = 0; i < nregs; i++)
for (int i = 0; i < GetLength(); i++)
{
result.set(i, get(i) - other.get(i));
}
Expand All @@ -6711,10 +6702,9 @@ struct WmmaFragment

__device__ This operator-()
{
constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
This result;

for (int i = 0; i < nregs; i++)
for (int i = 0; i < GetLength(); i++)
{
result.set(i, -get(i));
}
Expand All @@ -6723,10 +6713,9 @@ struct WmmaFragment

__device__ This operator+(const This& other)
{
constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
This result;

for (int i = 0; i < nregs; i++)
for (int i = 0; i < GetLength(); i++)
{
result.set(i, get(i) + other.get(i));
}
Expand All @@ -6735,10 +6724,9 @@ struct WmmaFragment

__device__ This operator%(const This& other)
{
constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
This result;

for (int i = 0; i < nregs; i++)
for (int i = 0; i < GetLength(); i++)
{
result.set(i, get(i) % other.get(i));
}
Expand All @@ -6751,7 +6739,7 @@ struct WmmaFragment
// If the data type is different, we need to copy element by element.
// Since the shape of two matrices are the same, they have the same
// number of elements.
for (int i = 0; i < elements_per_thread; i++)
for (int i = 0; i < GetLength(); i++)
{
set(i, static_cast<T>(other.get(i)));
}
Expand All @@ -6763,7 +6751,6 @@ struct WmmaFragment
// - index 1: bits [8:15] of regs[0]
// - index 2: bits [16:23] of regs[0]
// - index 3: bits [24:31] of regs[0]
// - index 4: bits [0:7] of regs[1], etc.
__device__ T get(int index) const
{
if constexpr (sizeof(T) == 4)
Expand Down Expand Up @@ -6848,40 +6835,52 @@ struct WmmaFragment
wmmaStore<T, M, N, K, layout>(buffer + element, regs, stride);
}

template<Layout layout, typename U>
void __device__ Store(U* buffer, uint stride)
{
// Force compile-time check, so we know the template parameter combination is valid.
(void)RegisterCount<T, M, N, K, R>::value;
wmmaStore<T, M, N, K, layout>(buffer, regs, stride * sizeof(U) / sizeof(T));
}

template<Layout layout>
static This __device__ Load(T* buffer, uint element, uint stride)
{
WmmaFragment<T, M, N, K, R, layout> fragment;
WmmaFragment<T, M, N, K, R> fragment;

// Force compile-time check, so we know the template parameter combination is valid.
(void)RegisterCount<T, M, N, K, R>::value;
wmmaLoad<T, M, N, K, R, layout>(fragment.regs, buffer + element, stride);
fragment.m_layout = layout;
return fragment;
}

template<Layout layout, typename U>
static This __device__ Load(U* buffer, uint stride)
{
WmmaFragment<T, M, N, K, R> fragment;

// Force compile-time check, so we know the template parameter combination is valid.
(void)RegisterCount<T, M, N, K, R>::value;
wmmaLoad<T, M, N, K, R, layout>(fragment.regs, buffer, stride * sizeof(U) / sizeof(T));
fragment.m_layout = layout;
return fragment;
}

static __device__ uint32_t GetLength() { return This::elements_per_thread; }
static constexpr __device__ uint32_t GetLength() { return This::elements_per_thread; }

// For referencing those template parameters outside the struct
using ElementType = T;
static constexpr int m_M = M;
static constexpr int m_N = N;
static constexpr int m_K = K;
static constexpr Layout m_layout = MatrixLayout;

// Maximum registers needed across all fragment types and data types
static constexpr int MAX_REGS = 8;
uint32_t regs[MAX_REGS] = {};
Layout m_layout = Layout::RowMajor;

static constexpr uint32_t elements_per_warp = (R == MatrixUse::MatrixA) ? (M * K)
: (R == MatrixUse::MatrixB) ? (K * N)
: (M * N);
// Register Count requirement
static constexpr int RegsCount = RegisterCount<T, M, N, K, R>::value;
unsigned regs[RegsCount] = {};

static_assert(elements_per_warp % 32 == 0, "Total elements per warp must be divisible by 32");

static constexpr uint32_t elements_per_thread = elements_per_warp / 32;
static constexpr uint32_t bytes_per_thread = elements_per_thread * sizeof(T);
static constexpr uint32_t registers_per_thread = (bytes_per_thread + 3) / 4;
static constexpr uint32_t elements_per_thread = RegsCount * (4 / sizeof(T));
};

// ====================================================================================
Expand Down Expand Up @@ -7350,12 +7349,10 @@ template<
int M,
int N,
int K,
Layout layoutA,
Layout layoutB,
bool saturatingAccumulation>
WmmaFragment<DType, M, N, K, MatrixC> __device__ coopMatMulAdd(
WmmaFragment<AType, M, N, K, MatrixUse::MatrixA, layoutA> matA,
WmmaFragment<BType, M, N, K, MatrixUse::MatrixB, layoutB> matB,
WmmaFragment<AType, M, N, K, MatrixUse::MatrixA> matA,
WmmaFragment<BType, M, N, K, MatrixUse::MatrixB> matB,
WmmaFragment<CType, M, N, K, MatrixUse::MatrixC> matC)
{
constexpr ShapeCombination shape = (M == 16 && N == 16 && K == 16) ? ShapeCombination::m16n16k16
Expand All @@ -7364,11 +7361,60 @@ WmmaFragment<DType, M, N, K, MatrixC> __device__ coopMatMulAdd(
: ShapeCombination::m32n8k16;

WmmaFragment<DType, M, N, K, MatrixC> matD;
MMAHelper<AType, BType, CType, DType, shape, layoutA, layoutB, saturatingAccumulation>::eval(
matD,
matA,
matB,
matC);
uint32_t encodedLayout = (matA.m_layout == Layout::RowMajor ? 1 : 0) << 1 |
(matB.m_layout == Layout::RowMajor ? 1 : 0);

switch (encodedLayout)
{
// 0b11: A row-major, B row-major
case 0x3:
MMAHelper<
AType,
BType,
CType,
DType,
shape,
Layout::RowMajor,
Layout::RowMajor,
saturatingAccumulation>::eval(matD, matA, matB, matC);
break;
// 0b10: A row-major, B col-major
case 0x2:
MMAHelper<
AType,
BType,
CType,
DType,
shape,
Layout::RowMajor,
Layout::ColMajor,
saturatingAccumulation>::eval(matD, matA, matB, matC);
break;
// 0b01: A col-major, B row-major
case 0x01:
MMAHelper<
AType,
BType,
CType,
DType,
shape,
Layout::ColMajor,
Layout::RowMajor,
saturatingAccumulation>::eval(matD, matA, matB, matC);
break;
// 0b00: A col-major, B col-major
case 0x00:
MMAHelper<
AType,
BType,
CType,
DType,
shape,
Layout::ColMajor,
Layout::ColMajor,
saturatingAccumulation>::eval(matD, matA, matB, matC);
break;
}

return matD;
}
Expand Down
Loading