Skip to content

Commit ead9ae2

Browse files
committed
Fix a specialization bug and refactor shMemPool
Fix a bug in specialization: peephole optimization must be run after an iteration of the specialization; otherwise, an expression such as sizeof(Type) won't be folded into a constant, which blocks the specialization of some generics. Also refactor the shMemPool type so that a type can now be used as the generic parameter when declaring the AccelerateVectorCoopMat type.
1 parent 6304d8a commit ead9ae2

File tree

4 files changed

+9
-11
lines changed

4 files changed

+9
-11
lines changed

source/slang/slang-ir-specialize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,7 @@ struct SpecializationContext
12371237
{
12381238
this->changed = true;
12391239
eliminateDeadCode(module->getModuleInst());
1240+
peepholeOptimizeGlobalScope(targetProgram, this->module);
12401241
applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink);
12411242
}
12421243

source/standard-modules/neural/accelerate-vector-coopmat.slang

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,8 +1011,6 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
10111011
public typealias Differential = AccelerateVectorCoopMat<T.Differential, ShMemPool, N, SubgroupSize>;
10121012

10131013
public static const int Size = N;
1014-
1015-
public no_diff SPtr<uint4> sharedMemoryPtr;
10161014
public no_diff ShMemPool shMemPool;
10171015

10181016
private typealias DTypeMatC = half;

source/standard-modules/neural/shared-memory-pool.slang

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@ internal typealias SPtr<T> = Ptr<T, Access::ReadWrite, AddressSpace::GroupShared
1414

1515
internal interface ISharedMemoryPool
1616
{
17-
static SPtr<uint4> getPointer();
17+
internal static SPtr<uint4> getPointer();
1818
}
1919

2020
public interface ISharedMemorySize
2121
{
2222
public static const uint Bytes;
2323
}
2424

25-
public struct SharedMemoryPool<int Bytes> : ISharedMemoryPool
25+
public struct SharedMemoryPool<ShMemSize: ISharedMemorySize> : ISharedMemoryPool
2626
{
27-
internal static groupshared uint4 data[Bytes / sizeof(uint4)];
27+
public static const uint sizeInBytes = ShMemSize.Bytes;
28+
internal static groupshared uint4 data[sizeInBytes / sizeof(uint4)];
2829
VISIBILITY_LEVEL static SPtr<uint4> getPointer()
2930
{
3031
return __getAddress(data[0]);

tests/neural/basic-coopmat-vector-test.slang

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ static const int BatchSize = 32;
3939
static const int SubgroupSize = 32;
4040
static const int workgroupCount = BatchSize / SubgroupSize;
4141

42-
typealias ShMemSize = SharedMemorySize<ElementType, TargetEnum.CUDA, ExecutionMode.Training, SubgroupSize, BatchSize/SubgroupSize>;
43-
static const int ShMemSizeInBytes = ShMemSize.OfLayer1<InputSize, OutputSize>.Bytes;
42+
typealias ShMemSize = SharedMemorySize< ElementType, TargetEnum.CUDA, ExecutionMode.Training, SubgroupSize, BatchSize / SubgroupSize>;
43+
typealias ShMemSizeLayer1 = ShMemSize.OfLayer1<InputSize, OutputSize>;
4444

4545
typealias SPtr<T> = Ptr<T, Access::ReadWrite, AddressSpace::GroupShared>;
4646

@@ -73,21 +73,20 @@ OutputVector TestInlineVectorMatMul<InputVector, OutputVector>(
7373
// Basic test on MatMul without bias, this test covers both forward and backward pass
7474
void BasicTestWithoutBias(int tid, int resIndex)
7575
{
76-
typealias ShMemPool = SharedMemoryPool<ShMemSizeInBytes>;
76+
typealias ShMemPool = SharedMemoryPool<ShMemSizeLayer1>;
7777
typealias InVectorType = AccelerateVectorCoopMat<ElementType, ShMemPool, InputSize, SubgroupSize>;
7878
typealias OutVectorType = AccelerateVectorCoopMat<ElementType, ShMemPool, OutputSize, SubgroupSize>;
7979

8080
ElementType[InputSize] inputData = { ElementType(1.0), ElementType(2.0), ElementType(3.0), ElementType(4.0) };
8181
InVectorType input = InVectorType(inputData);
8282

83-
input.sharedMemoryPtr = ShMemPool.getPointer();
84-
8583
BufferStorage weightStorage = BufferStorage(parameters);
8684
BufferStorage dweightStorage = BufferStorage(dParameters);
8785
BufferStorage.Address weightAddress = 0;
8886

8987
// Run the forward pass
9088
let outputVec = TestInlineVectorMatMul<InVectorType, OutVectorType>(input, weightStorage, weightAddress);
89+
9190
// serialRead<16, half>(tid, __getAddress(shMem[0]));
9291
// serialRead<16, half>(tid, __getAddress(shMem[0]) + 32);
9392

@@ -96,7 +95,6 @@ void BasicTestWithoutBias(int tid, int resIndex)
9695
bool isPassed = true;
9796
isPassed = isPassed && (outputVec[0] == 30.0 && outputVec[1] == 70.0);
9897

99-
10098
var weightDiffPair = DifferentialPtrPair<BufferStorage>(weightStorage, dweightStorage);
10199
let dRes = OutVectorType(1.0f);
102100
var dPair = diffPair(input);

0 commit comments

Comments
 (0)