Skip to content

Commit 6c183b1

Browse files
committed
reserve a layout generic parameter for future optimization
1 parent 87d1b1c commit 6c183b1

File tree

9 files changed

+81
-50
lines changed

9 files changed

+81
-50
lines changed

source/standard-modules/neural/accelerate-vector-coopmat.slang

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ VISIBILITY_LEVEL struct MMAHelper<T, int InputSize, int OutputSize, int Subgroup
711711
break;
712712

713713
let offset = Storage.getOffset(biasAddress, index);
714-
biasStorage.write(offset, localResult[index]);
714+
biasStorage.atomicAdd(offset, localResult[index]);
715715
}
716716
}
717717
}
@@ -1073,7 +1073,7 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
10731073
set { this.data[index] = newValue; }
10741074
}
10751075

1076-
private OutputVector linearTransformOnTarget<Storage, OutputVector, TargetEnum Target, bool Bias>(
1076+
private OutputVector linearTransformOnTarget<Storage, Layout, OutputVector, TargetEnum Target, bool Bias>(
10771077
Storage weight,
10781078
no_diff Storage.Address weightAddress,
10791079
no_diff Optional<Storage> bias,
@@ -1093,7 +1093,7 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
10931093
return OutputVector(outputArray);
10941094
}
10951095

1096-
private static void linearTransformBwdOnTarget< Storage, OutputVector, TargetEnum Target, bool Bias>(
1096+
private static void linearTransformBwdOnTarget<Storage, Layout, OutputVector, TargetEnum Target, bool Bias>(
10971097
inout DifferentialPair<This> dthis,
10981098
DifferentialPtrPair<Storage> dWeightStorage,
10991099
no_diff Storage.Address dWeightAddress,
@@ -1158,70 +1158,71 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
11581158
// Linear transformation without bias
11591159
[Differentiable]
11601160
[BackwardDerivative(linearTransformBwd)]
1161-
public OutputVector linearTransform<Storage, OutputVector>(
1161+
public OutputVector linearTransform<Storage, Layout, OutputVector>(
11621162
Storage weightStorage,
11631163
no_diff Storage.Address weightAddress)
11641164
where Storage : IStorage<T>
11651165
where Storage.Differential : IStorage<T.Differential>
11661166
where Storage.Address == Storage.Differential.Address
1167+
where Layout : IStorageLayout
11671168
where OutputVector : IVector<T>
11681169
{
11691170
__target_switch
11701171
{
11711172
case cuda:
1172-
return no_diff linearTransformOnTarget<Storage, OutputVector, TargetEnum.CUDA, false>(weightStorage, weightAddress, none, none);
1173+
return no_diff linearTransformOnTarget<Storage, Layout, OutputVector, TargetEnum.CUDA, false>(weightStorage, weightAddress, none, none);
11731174
case spirv:
1174-
return no_diff linearTransformOnTarget<Storage, OutputVector, TargetEnum.SPIR_V, false>(weightStorage, weightAddress, none, none);
1175+
return no_diff linearTransformOnTarget<Storage, Layout, OutputVector, TargetEnum.SPIR_V, false>(weightStorage, weightAddress, none, none);
11751176
}
11761177
}
11771178

11781179
// Backward of linear transformation without bias
1179-
static void linearTransformBwd<Storage, OutputVector>(
1180+
static void linearTransformBwd<Storage, Layout, OutputVector>(
11801181
inout DifferentialPair<This> dthis,
11811182
DifferentialPtrPair<Storage> dWeightStorage,
11821183
no_diff Storage.Address dWeightAddress,
11831184
OutputVector.Differential doutput)
11841185
where Storage : IStorage<T>
11851186
where Storage.Differential : IStorage<T.Differential>
11861187
where Storage.Address == Storage.Differential.Address
1188+
where Layout : IStorageLayout
11871189
where OutputVector : IVector<T>
11881190
where OutputVector.Differential : IVector<T.Differential>
11891191
{
11901192
Optional<DifferentialPtrPair<Storage>> biasStorage = {};
11911193
__target_switch
11921194
{
11931195
case cuda:
1194-
linearTransformBwdOnTarget<Storage, OutputVector, TargetEnum.CUDA, false>(dthis, dWeightStorage, dWeightAddress, biasStorage, none, doutput);
1196+
linearTransformBwdOnTarget<Storage, Layout, OutputVector, TargetEnum.CUDA, false>(dthis, dWeightStorage, dWeightAddress, biasStorage, none, doutput);
11951197
case spirv:
1196-
linearTransformBwdOnTarget<Storage, OutputVector, TargetEnum.SPIR_V, false>(dthis, dWeightStorage, dWeightAddress, biasStorage, none, doutput);
1198+
linearTransformBwdOnTarget<Storage, Layout, OutputVector, TargetEnum.SPIR_V, false>(dthis, dWeightStorage, dWeightAddress, biasStorage, none, doutput);
11971199
}
11981200
}
11991201

1200-
1201-
12021202
[Differentiable]
12031203
[BackwardDerivative(linearTransformBwd)]
1204-
public OutputVector linearTransform<Storage, OutputVector>(
1204+
public OutputVector linearTransform<Storage, Layout, OutputVector>(
12051205
Storage weightStorage,
12061206
Storage biasStorage,
12071207
no_diff Storage.Address weightAddress,
12081208
no_diff Storage.Address biasAddress)
12091209
where Storage : IStorage<T>
12101210
where Storage.Differential : IStorage<T.Differential>
12111211
where Storage.Address == Storage.Differential.Address
1212+
where Layout : IStorageLayout
12121213
where OutputVector : IVector<T>
12131214
{
12141215
__target_switch
12151216
{
12161217
case cuda:
1217-
return no_diff linearTransformOnTarget<Storage, OutputVector, TargetEnum.CUDA, true>(weightStorage, weightAddress, biasStorage, biasAddress);
1218+
return no_diff linearTransformOnTarget<Storage, Layout, OutputVector, TargetEnum.CUDA, true>(weightStorage, weightAddress, biasStorage, biasAddress);
12181219
case spirv:
1219-
return no_diff linearTransformOnTarget<Storage, OutputVector, TargetEnum.SPIR_V, true>(weightStorage, weightAddress, biasStorage, biasAddress);
1220+
return no_diff linearTransformOnTarget<Storage, Layout, OutputVector, TargetEnum.SPIR_V, true>(weightStorage, weightAddress, biasStorage, biasAddress);
12201221
}
12211222
}
12221223

12231224
// Backward of linear transformation with bias
1224-
static void linearTransformBwd<Storage, OutputVector>(
1225+
static void linearTransformBwd<Storage, Layout, OutputVector>(
12251226
inout DifferentialPair<This> dthis,
12261227
DifferentialPtrPair<Storage> dWeightStorage,
12271228
DifferentialPtrPair<Storage> dBiasStorage,
@@ -1231,34 +1232,37 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
12311232
where Storage : IStorage<T>
12321233
where Storage.Differential : IStorage<T.Differential>
12331234
where Storage.Address == Storage.Differential.Address
1235+
where Layout : IStorageLayout
12341236
where OutputVector : IVector<T>
12351237
where OutputVector.Differential : IVector<T.Differential>
12361238
{
12371239
__target_switch
12381240
{
12391241
case cuda:
1240-
linearTransformBwdOnTarget<Storage, OutputVector, TargetEnum.CUDA, true>( dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput);
1242+
linearTransformBwdOnTarget<Storage, Layout, OutputVector, TargetEnum.CUDA, true>( dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput);
12411243
case spirv:
1242-
linearTransformBwdOnTarget<Storage, OutputVector, TargetEnum.SPIR_V, true>(dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput);
1244+
linearTransformBwdOnTarget<Storage, Layout, OutputVector, TargetEnum.SPIR_V, true>(dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput);
12431245
}
12441246
}
12451247

12461248
[Differentiable]
1247-
public OutputVector linearTransform<Address, OutputVector>(
1249+
public OutputVector linearTransform<Address, Layout, OutputVector>(
12481250
Address weightAddress)
12491251
where Address : IPointerLikeAddress<T>
12501252
where Address.Differential : IPointerLikeAddress<T.Differential>
1253+
where Layout : IStorageLayout
12511254
where OutputVector : IVector<T>
12521255
{
12531256
OutputVector output = OutputVector();
12541257
return output;
12551258
}
12561259

12571260
[Differentiable]
1258-
public OutputVector linearTransform<Address, OutputVector>(
1261+
public OutputVector linearTransform<Address, Layout, OutputVector>(
12591262
Address weightAddress, Address biasAddress)
12601263
where Address : IPointerLikeAddress<T>
12611264
where Address.Differential : IPointerLikeAddress<T.Differential>
1265+
where Layout : IStorageLayout
12621266
where OutputVector : IVector<T>
12631267
{
12641268
OutputVector output = OutputVector();

source/standard-modules/neural/buffer-storage.slang

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public struct StructuredBufferStorage<T> : IStorage<T>
1818
/// Address type is a simple unsigned integer index.
1919
public typealias Address = uint;
2020

21-
/// The underlying buffer type.
21+
/// The underlying buffer type.
2222
public typealias BufferType = RWStructuredBuffer<T>;
2323

2424
/// Differential type for automatic differentiation.

source/standard-modules/neural/inline-vector.slang

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,13 @@ public struct InlineVector<T, int N> : IVector<T>
7676

7777
// Linear transformation without bias
7878
[BackwardDerivative(linearTransformBwd)]
79-
public OutputVector linearTransform<Storage, OutputVector>(
79+
public OutputVector linearTransform<Storage, Layout, OutputVector>(
8080
Storage weightStorage,
8181
no_diff Storage.Address weightAddress)
8282
where Storage : IStorage<T>
8383
where Storage.Differential : IStorage<T.Differential>
8484
where Storage.Address == Storage.Differential.Address
85+
where Layout : IStorageLayout
8586
where OutputVector : IVector<T>
8687
{
8788
OutputVector output = OutputVector();
@@ -105,18 +106,19 @@ public struct InlineVector<T, int N> : IVector<T>
105106

106107
// Linear transformation with bias
107108
[BackwardDerivative(linearTransformBwd)]
108-
public OutputVector linearTransform<Storage, OutputVector>(
109+
public OutputVector linearTransform<Storage, Layout, OutputVector>(
109110
Storage weightStorage,
110111
Storage biasStorage,
111112
no_diff Storage.Address weightAddress,
112113
no_diff Storage.Address biasAddress)
113114
where Storage : IStorage<T>
114115
where Storage.Differential : IStorage<T.Differential>
115116
where Storage.Address == Storage.Differential.Address
117+
where Layout : IStorageLayout
116118
where OutputVector : IVector<T>
117119
{
118120
// Reuse the no-bias matmul method
119-
OutputVector output = this.linearTransform<Storage, OutputVector>(weightStorage, weightAddress);
121+
OutputVector output = this.linearTransform<Storage, Layout, OutputVector>(weightStorage, weightAddress);
120122

121123
// apply the bias
122124
[ForceUnroll]
@@ -130,14 +132,15 @@ public struct InlineVector<T, int N> : IVector<T>
130132
}
131133

132134
// Backward of linear transformation without bias
133-
static void linearTransformBwd<Storage, OutputVector>(
135+
static void linearTransformBwd<Storage, Layout, OutputVector>(
134136
inout DifferentialPair<This> dthis,
135137
DifferentialPtrPair<Storage> dWeightStorage,
136138
no_diff Storage.Address dWeightAddress,
137139
OutputVector.Differential doutput)
138140
where Storage : IStorage<T>
139141
where Storage.Differential : IStorage<T.Differential>
140142
where Storage.Address == Storage.Differential.Address
143+
where Layout : IStorageLayout
141144
where OutputVector : IVector<T>
142145
{
143146
// Derivative of the input is transposed weight matrix times the output differential
@@ -180,7 +183,7 @@ public struct InlineVector<T, int N> : IVector<T>
180183
}
181184

182185
// Backward of linear transformation with bias
183-
static void linearTransformBwd<Storage, OutputVector>(
186+
static void linearTransformBwd<Storage, Layout, OutputVector>(
184187
inout DifferentialPair<This> dthis,
185188
DifferentialPtrPair<Storage> dWeightStorage,
186189
DifferentialPtrPair<Storage> dBiasStorage,
@@ -190,10 +193,11 @@ public struct InlineVector<T, int N> : IVector<T>
190193
where Storage : IStorage<T>
191194
where Storage.Differential : IStorage<T.Differential>
192195
where Storage.Address == Storage.Differential.Address
196+
where Layout : IStorageLayout
193197
where OutputVector : IVector<T>
194198
{
195199
// Reuse the no-bias backward method
196-
linearTransformBwd<Storage, OutputVector>(dthis, dWeightStorage, dWeightAddress, doutput);
200+
linearTransformBwd<Storage, Layout, OutputVector>(dthis, dWeightStorage, dWeightAddress, doutput);
197201

198202
// Derivative of the bias is the same as the output differential
199203
[ForceUnroll]
@@ -206,10 +210,11 @@ public struct InlineVector<T, int N> : IVector<T>
206210

207211
// Linear transformation without bias (Bindless storage)
208212
[BackwardDerivative(linearTransformBwd)]
209-
public OutputVector linearTransform<Address, OutputVector>(
213+
public OutputVector linearTransform<Address, Layout, OutputVector>(
210214
Address weightAddress)
211215
where Address : IPointerLikeAddress<T>
212216
where Address.Differential : IPointerLikeAddress<T.Differential>
217+
where Layout : IStorageLayout
213218
where OutputVector : IVector<T>
214219
{
215220
var output = OutputVector();
@@ -231,15 +236,16 @@ public struct InlineVector<T, int N> : IVector<T>
231236

232237
// Linear transformation with bias (Bindless storage)
233238
[BackwardDerivative(linearTransformBwd)]
234-
public OutputVector linearTransform<Address, OutputVector>(
239+
public OutputVector linearTransform<Address, Layout, OutputVector>(
235240
Address weightAddress,
236241
Address biasAddress)
237242
where Address : IPointerLikeAddress<T>
238243
where Address.Differential : IPointerLikeAddress<T.Differential>
244+
where Layout : IStorageLayout
239245
where OutputVector : IVector<T>
240246
{
241247
// Reuse the no-bias matmul method
242-
OutputVector output = this.linearTransform<Address, OutputVector>(weightAddress);
248+
OutputVector output = this.linearTransform<Address, Layout, OutputVector>(weightAddress);
243249

244250
[ForceUnroll]
245251
for (int i = 0; i < OutputVector.Size; i++)
@@ -249,14 +255,15 @@ public struct InlineVector<T, int N> : IVector<T>
249255
}
250256

251257
// Backward of linear transformation without bias (Bindless storage)
252-
static public void linearTransformBwd<Address, OutputVector>(
258+
static public void linearTransformBwd<Address, Layout, OutputVector>(
253259
inout DifferentialPair<This> dthis,
254260
DifferentialPtrPair<Address> dparameters,
255261
OutputVector.Differential doutput)
256-
where Address : IPointerLikeAddress<T>
257-
where Address.Differential : IPointerLikeAddress<T.Differential>
258-
where OutputVector : IVector<T>
259-
where OutputVector.Differential : IVector<T.Differential>
262+
where Address : IPointerLikeAddress<T>
263+
where Address.Differential : IPointerLikeAddress<T.Differential>
264+
where Layout : IStorageLayout
265+
where OutputVector : IVector<T>
266+
where OutputVector.Differential : IVector<T.Differential>
260267
{
261268
// dInput = dW^T * dOutput
262269
This.Differential d = {};
@@ -292,17 +299,18 @@ public struct InlineVector<T, int N> : IVector<T>
292299
}
293300

294301
// Backward of linear transformation with bias (Bindless storage)
295-
static public void linearTransformBwd<Address, OutputVector>(
302+
static public void linearTransformBwd<Address, Layout, OutputVector>(
296303
inout DifferentialPair<This> dthis,
297304
DifferentialPtrPair<Address> dWeightAddress,
298305
DifferentialPtrPair<Address> dBiasAddress,
299306
OutputVector.Differential doutput)
300-
where Address : IPointerLikeAddress<T>
301-
where Address.Differential : IPointerLikeAddress<T.Differential>
302-
where OutputVector : IVector<T>
307+
where Address : IPointerLikeAddress<T>
308+
where Address.Differential : IPointerLikeAddress<T.Differential>
309+
where Layout : IStorageLayout
310+
where OutputVector : IVector<T>
303311
{
304312
// Reuse the no-bias backward method
305-
linearTransformBwd<Address, OutputVector>(dthis, dWeightAddress, doutput);
313+
linearTransformBwd<Address, Layout, OutputVector>(dthis, dWeightAddress, doutput);
306314

307315
let biasOffset = dBiasAddress.d.getOffset(0);
308316
// dBias = dOutput

source/standard-modules/neural/istorages.slang

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
implementing neural;
22

3+
public enum LayoutType : uint32_t
4+
{
5+
Linear = 0,
6+
}
7+
8+
internal interface IStorageLayout
9+
{
10+
internal static const LayoutType Layout;
11+
}
12+
13+
public struct LinearLayout : IStorageLayout
14+
{
15+
internal static const LayoutType Layout = LayoutType.Linear;
16+
}
17+
318
/**
419
Storage interface for accessing neural network parameters.
520
Provides an abstraction for reading/writing parameters from various storage backends

source/standard-modules/neural/ivector.slang

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,13 @@ public interface IVector<T> : IDifferentiable, IArrayAccessor<T>
7777
- `OutputVector` must conform to `IVector<T>`
7878
*/
7979
[Differentiable]
80-
public OutputVector linearTransform<Storage, OutputVector>(
80+
public OutputVector linearTransform<Storage, Layout, OutputVector>(
8181
Storage weightStorage,
8282
no_diff Storage.Address weightAddress)
8383
where Storage : IStorage<T>
8484
where Storage.Differential : IStorage<T.Differential>
8585
where Storage.Address == Storage.Differential.Address
86+
where Layout : IStorageLayout
8687
where OutputVector : IVector<T>;
8788

8889
/**
@@ -109,14 +110,15 @@ public interface IVector<T> : IDifferentiable, IArrayAccessor<T>
109110
- `OutputVector` must conform to `IVector<T>`
110111
*/
111112
[Differentiable]
112-
public OutputVector linearTransform<Storage, OutputVector>(
113+
public OutputVector linearTransform<Storage, Layout, OutputVector>(
113114
Storage weightStorage,
114115
Storage biasStorage,
115116
no_diff Storage.Address weightAddress,
116117
no_diff Storage.Address biasAddress)
117118
where Storage : IStorage<T>
118119
where Storage.Differential : IStorage<T.Differential>
119120
where Storage.Address == Storage.Differential.Address
121+
where Layout : IStorageLayout
120122
where OutputVector : IVector<T>;
121123

122124
/**
@@ -137,9 +139,10 @@ public interface IVector<T> : IDifferentiable, IArrayAccessor<T>
137139
- `OutputVector` must conform to `IVector<T, OutputSize>`
138140
*/
139141
[Differentiable]
140-
public OutputVector linearTransform<Address, OutputVector>(Address weightAddress)
142+
public OutputVector linearTransform<Address, Layout, OutputVector>(Address weightAddress)
141143
where Address : IPointerLikeAddress<T>
142144
where Address.Differential : IPointerLikeAddress<T.Differential>
145+
where Layout : IStorageLayout
143146
where OutputVector : IVector<T>;
144147

145148
/**
@@ -165,9 +168,10 @@ public interface IVector<T> : IDifferentiable, IArrayAccessor<T>
165168
- `OutputVector` must conform to `IVector<T, OutputSize>`
166169
*/
167170
[Differentiable]
168-
public OutputVector linearTransform<Address, OutputVector>(
171+
public OutputVector linearTransform<Address, Layout, OutputVector>(
169172
Address weightAddress, Address biasAddress)
170173
where Address : IPointerLikeAddress<T>
171174
where Address.Differential : IPointerLikeAddress<T.Differential>
175+
where Layout : IStorageLayout
172176
where OutputVector : IVector<T>;
173177
}

0 commit comments

Comments
 (0)