Skip to content

Commit 0455177

Browse files
committed
update the name
1 parent 2878f4e commit 0455177

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

source/standard-modules/neural/accelerate-vector-coopmat.slang

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,12 +984,14 @@ VISIBILITY_LEVEL struct MMAHelper<T, int InputSize, int OutputSize, int Subgroup
984984
}
985985

986986
// Cooperative matrix is only supported by CUDA and SPIR-V
987+
// WaveTangledVector is a vector type that emulates the Cooperative Vector type by using Cooperative Matrix feature which is
988+
// supported by CUDA and SPIR-V.
987989
[require(cooperative_matrix, subgroup_basic)]
988-
public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, int SubgroupSize> : IVector<T>
990+
public struct WaveTangledVector<T, ShMemPool : ISharedMemoryPool, int N, int SubgroupSize> : IVector<T>
989991
where T : __BuiltinFloatingPointType
990992
where T.Differential == T
991993
{
992-
public typealias Differential = AccelerateVectorCoopMat<T.Differential, ShMemPool, N, SubgroupSize>;
994+
public typealias Differential = WaveTangledVector<T.Differential, ShMemPool, N, SubgroupSize>;
993995

994996
public static const int Size = N;
995997
public no_diff ShMemPool shMemPool;

tests/neural/basic-coopmat-vector-test.slang

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ OutputVector MatVecMulAdd<InputVector, OutputVector>(
7474
void BasicTestWithoutBias(int tid, int resIndex)
7575
{
7676
typealias ShMemPool = SharedMemoryPool<ShMemSizeLayer1>;
77-
typealias InVectorType = AccelerateVectorCoopMat<ElementType, ShMemPool, InputSize, SubgroupSize>;
78-
typealias OutVectorType = AccelerateVectorCoopMat<ElementType, ShMemPool, OutputSize, SubgroupSize>;
77+
typealias InVectorType = WaveTangledVector<ElementType, ShMemPool, InputSize, SubgroupSize>;
78+
typealias OutVectorType = WaveTangledVector<ElementType, ShMemPool, OutputSize, SubgroupSize>;
7979

8080
ElementType[InputSize] inputData = { ElementType(1.0), ElementType(2.0), ElementType(3.0), ElementType(4.0) };
8181
InVectorType input = InVectorType(inputData);
@@ -128,8 +128,8 @@ void BasicTestWithoutBias(int tid, int resIndex)
128128
void BasicTestWithBias(int tid, int resIndex)
129129
{
130130
typealias ShMemPool = SharedMemoryPool<ShMemSizeLayer1>;
131-
typealias InVectorType = AccelerateVectorCoopMat<ElementType, ShMemPool, InputSize, SubgroupSize>;
132-
typealias OutVectorType = AccelerateVectorCoopMat<ElementType, ShMemPool, OutputSize, SubgroupSize>;
131+
typealias InVectorType = WaveTangledVector<ElementType, ShMemPool, InputSize, SubgroupSize>;
132+
typealias OutVectorType = WaveTangledVector<ElementType, ShMemPool, OutputSize, SubgroupSize>;
133133

134134
ElementType[InputSize] inputData = { ElementType(1.0), ElementType(2.0), ElementType(3.0), ElementType(4.0) };
135135
InVectorType input = InVectorType(inputData);

0 commit comments

Comments
 (0)