Skip to content

Commit 48d73cd

Browse files
authored
neural: Fix a sync issue (#10281)
Closes #10272. The root cause of this issue is a race condition. Add two syncs to avoid it: 1. At the beginning of the mma loop -> This prevents a fast warp from starting to write shared memory A in iteration i+1 while other, slower warps are still reading shared memory A. 2. In the backward pass, after mma and before outerProductAccumulate -> because mma ends with only a warp-level sync, outerProductAccumulate could start executing on some fast warps while the slow warps are still running mma. This PR also fixes a numerical issue in the activation functions: every function using exp() was not numerically stable, so they are updated here.
1 parent d76973a commit 48d73cd

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

source/standard-modules/neural/accelerate-vector-coopmat.slang

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,11 @@ VISIBILITY_LEVEL struct MMAHelper<T, int InputSize, int OutputSize, int Subgroup
10441044
}
10451045
else
10461046
{
1047+
// Ensure all warps finished reading shared memory from the previous
1048+
// tile iteration (or from the previous layer's output writeback in
1049+
// fused multi-layer kernels) before overwriting with new tile data.
1050+
GroupMemoryBarrierWithGroupSync();
1051+
10471052
loadShA<U, Address>(ptrAOffset[0], tileIndex, weightAddress);
10481053
loadVectorToShB(ptrBOffset[0], tileIndex, subgroupIndex, inputVector);
10491054

@@ -1271,6 +1276,13 @@ public struct WaveTangledVector<T, ShMemSize : ISharedMemorySize, int N, int Sub
12711276
// outerProductAccumulate uses per-warp shared memory for both A (dOutput vectors)
12721277
// and B (input vectors). The B region must start after ALL per-warp A regions to
12731278
// avoid overlapping writes between warps (warp i's A would alias warp (i-1)'s B).
1279+
//
1280+
// Sync required: mma() above ends with warp-level sync only, but
1281+
// outerProductAccumulate reuses the same shared memory pool (starting at shA=0).
1282+
// Without a group sync, fast warps starting the outer product would corrupt
1283+
// slow warps still reading from mma's output writeback.
1284+
GroupMemoryBarrierWithGroupSync();
1285+
12741286
static const int _outerRows = (MMA.M + MMA.CMShape.ROW_A - 1) / MMA.CMShape.ROW_A;
12751287
static const int _outerPerWarpA = _outerRows * MMA.CMShape.ROW_A / MMA.CMShape.ElementCountPerVector * MMA.CMShape.COLUMN_A;
12761288
uint shB_outer = shA + uint(getWaveCount() * _outerPerWarpA);
@@ -1284,6 +1296,8 @@ public struct WaveTangledVector<T, ShMemSize : ISharedMemorySize, int N, int Sub
12841296
>( doutput, shA, shB, shC, dWeightAddress.p, none);
12851297
dthis = DifferentialPair<This>(dthis.p, dInput);
12861298
1299+
GroupMemoryBarrierWithGroupSync();
1300+
12871301
static const int _outerRows2 = (MMA.M + MMA.CMShape.ROW_A - 1) / MMA.CMShape.ROW_A;
12881302
static const int _outerPerWarpA2 = _outerRows2 * MMA.CMShape.ROW_A / MMA.CMShape.ElementCountPerVector * MMA.CMShape.COLUMN_A;
12891303
uint shB_outer2 = shA + uint(getWaveCount() * _outerPerWarpA2);

source/standard-modules/neural/activations.slang

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,10 @@ public struct Sigmoid<T> : IActivation<T>
9494
[ForceUnroll]
9595
for (int i = 0; i < Vector.Size; i++)
9696
{
97+
// Numerically stable: exp() argument is always <= 0
9798
let x = input[i];
98-
output[i] = T(1) / (T(1) + exp(-x));
99+
let ex = exp(x >= T(0) ? -x : x);
100+
output[i] = x >= T(0) ? T(1) / (T(1) + ex) : ex / (T(1) + ex);
99101
}
100102
return output;
101103
}
@@ -185,8 +187,10 @@ public struct SiLU<T> : IActivation<T>
185187
[ForceUnroll]
186188
for (int i = 0; i < Vector.Size; i++)
187189
{
190+
// x * sigmoid(x), numerically stable: exp() argument is always <= 0
188191
let x = input[i];
189-
output[i] = x / (T(1) + exp(-x)); // x * sigmoid(x)
192+
let ex = exp(x >= T(0) ? -x : x);
193+
output[i] = x >= T(0) ? x / (T(1) + ex) : x * ex / (T(1) + ex);
190194
}
191195
return output;
192196
}
@@ -212,8 +216,11 @@ public struct QuickGELU<T> : IActivation<T>
212216
[ForceUnroll]
213217
for (int i = 0; i < Vector.Size; i++)
214218
{
219+
// x * sigmoid(1.702 * x), numerically stable: exp() argument is always <= 0
215220
let x = input[i];
216-
output[i] = x / (T(1) + exp(T(-1.702) * x)); // x * sigmoid(1.702 * x)
221+
let sx = T(1.702) * x;
222+
let ex = exp(sx >= T(0) ? -sx : sx);
223+
output[i] = sx >= T(0) ? x / (T(1) + ex) : x * ex / (T(1) + ex);
217224
}
218225
return output;
219226
}

0 commit comments

Comments
 (0)