@@ -711,7 +711,7 @@ VISIBILITY_LEVEL struct MMAHelper<T, int InputSize, int OutputSize, int Subgroup
711711 break ;
712712
713713 let offset = Storage .getOffset (biasAddress, index);
714- biasStorage .write (offset, localResult [index]);
714+ biasStorage .atomicAdd (offset, localResult [index]);
715715 }
716716 }
717717 }
@@ -1073,7 +1073,7 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
10731073 set { this .data [index] = newValue; }
10741074 }
10751075
1076- private OutputVector linearTransformOnTarget< Storage, OutputVector, TargetEnum Target, bool Bias> (
1076+ private OutputVector linearTransformOnTarget< Storage, Layout, OutputVector, TargetEnum Target, bool Bias> (
10771077 Storage weight,
10781078 no_diff Storage .Address weightAddress,
10791079 no_diff Optional< Storage> bias,
@@ -1093,7 +1093,7 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
10931093 return OutputVector(outputArray);
10941094 }
10951095
1096- private static void linearTransformBwdOnTarget< Storage, OutputVector, TargetEnum Target, bool Bias> (
1096+ private static void linearTransformBwdOnTarget< Storage, Layout , OutputVector, TargetEnum Target, bool Bias> (
10971097 inout DifferentialPair< This> dthis,
10981098 DifferentialPtrPair< Storage> dWeightStorage,
10991099 no_diff Storage .Address dWeightAddress,
@@ -1158,70 +1158,71 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
11581158 // Linear transformation without bias
11591159 [Differentiable]
11601160 [BackwardDerivative(linearTransformBwd)]
1161- public OutputVector linearTransform< Storage, OutputVector> (
1161+ public OutputVector linearTransform< Storage, Layout, OutputVector> (
11621162 Storage weightStorage,
11631163 no_diff Storage .Address weightAddress)
11641164 where Storage : IStorage< T>
11651165 where Storage .Differential : IStorage< T .Differential >
11661166 where Storage .Address == Storage .Differential .Address
1167+ where Layout : IStorageLayout
11671168 where OutputVector : IVector< T>
11681169 {
11691170 __target_switch
11701171 {
11711172 case cuda:
1172- return no_diff linearTransformOnTarget< Storage, OutputVector, TargetEnum .CUDA , false > (weightStorage, weightAddress, none , none );
1173+ return no_diff linearTransformOnTarget< Storage, Layout, OutputVector, TargetEnum .CUDA , false > (weightStorage, weightAddress, none , none );
11731174 case spirv:
1174- return no_diff linearTransformOnTarget< Storage, OutputVector, TargetEnum .SPIR_V , false > (weightStorage, weightAddress, none , none );
1175+ return no_diff linearTransformOnTarget< Storage, Layout, OutputVector, TargetEnum .SPIR_V , false > (weightStorage, weightAddress, none , none );
11751176 }
11761177 }
11771178
11781179 // Backward of linear transformation without bias
1179- static void linearTransformBwd< Storage, OutputVector> (
1180+ static void linearTransformBwd< Storage, Layout, OutputVector> (
11801181 inout DifferentialPair< This> dthis,
11811182 DifferentialPtrPair< Storage> dWeightStorage,
11821183 no_diff Storage .Address dWeightAddress,
11831184 OutputVector .Differential doutput)
11841185 where Storage : IStorage< T>
11851186 where Storage .Differential : IStorage< T .Differential >
11861187 where Storage .Address == Storage .Differential .Address
1188+ where Layout : IStorageLayout
11871189 where OutputVector : IVector< T>
11881190 where OutputVector .Differential : IVector< T .Differential >
11891191 {
11901192 Optional< DifferentialPtrPair< Storage>> biasStorage = {};
11911193 __target_switch
11921194 {
11931195 case cuda:
1194- linearTransformBwdOnTarget< Storage, OutputVector, TargetEnum .CUDA , false > (dthis, dWeightStorage, dWeightAddress, biasStorage, none , doutput);
1196+ linearTransformBwdOnTarget< Storage, Layout, OutputVector, TargetEnum .CUDA , false > (dthis, dWeightStorage, dWeightAddress, biasStorage, none , doutput);
11951197 case spirv:
1196- linearTransformBwdOnTarget< Storage, OutputVector, TargetEnum .SPIR_V , false > (dthis, dWeightStorage, dWeightAddress, biasStorage, none , doutput);
1198+ linearTransformBwdOnTarget< Storage, Layout, OutputVector, TargetEnum .SPIR_V , false > (dthis, dWeightStorage, dWeightAddress, biasStorage, none , doutput);
11971199 }
11981200 }
11991201
1200-
1201-
12021202 [Differentiable]
12031203 [BackwardDerivative(linearTransformBwd)]
1204- public OutputVector linearTransform< Storage, OutputVector> (
1204+ public OutputVector linearTransform< Storage, Layout, OutputVector> (
12051205 Storage weightStorage,
12061206 Storage biasStorage,
12071207 no_diff Storage .Address weightAddress,
12081208 no_diff Storage .Address biasAddress)
12091209 where Storage : IStorage< T>
12101210 where Storage .Differential : IStorage< T .Differential >
12111211 where Storage .Address == Storage .Differential .Address
1212+ where Layout : IStorageLayout
12121213 where OutputVector : IVector< T>
12131214 {
12141215 __target_switch
12151216 {
12161217 case cuda:
1217- return no_diff linearTransformOnTarget< Storage, OutputVector, TargetEnum .CUDA , true > (weightStorage, weightAddress, biasStorage, biasAddress);
1218+ return no_diff linearTransformOnTarget< Storage, Layout, OutputVector, TargetEnum .CUDA , true > (weightStorage, weightAddress, biasStorage, biasAddress);
12181219 case spirv:
1219- return no_diff linearTransformOnTarget< Storage, OutputVector, TargetEnum .SPIR_V , true > (weightStorage, weightAddress, biasStorage, biasAddress);
1220+ return no_diff linearTransformOnTarget< Storage, Layout, OutputVector, TargetEnum .SPIR_V , true > (weightStorage, weightAddress, biasStorage, biasAddress);
12201221 }
12211222 }
12221223
12231224 // Backward of linear transformation with bias
1224- static void linearTransformBwd< Storage, OutputVector> (
1225+ static void linearTransformBwd< Storage, Layout, OutputVector> (
12251226 inout DifferentialPair< This> dthis,
12261227 DifferentialPtrPair< Storage> dWeightStorage,
12271228 DifferentialPtrPair< Storage> dBiasStorage,
@@ -1231,34 +1232,37 @@ public struct AccelerateVectorCoopMat<T, ShMemPool : ISharedMemoryPool, int N, i
12311232 where Storage : IStorage< T>
12321233 where Storage .Differential : IStorage< T .Differential >
12331234 where Storage .Address == Storage .Differential .Address
1235+ where Layout : IStorageLayout
12341236 where OutputVector : IVector< T>
12351237 where OutputVector .Differential : IVector< T .Differential >
12361238 {
12371239 __target_switch
12381240 {
12391241 case cuda:
1240- linearTransformBwdOnTarget< Storage, OutputVector, TargetEnum .CUDA , true > ( dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput);
1242+ linearTransformBwdOnTarget< Storage, Layout, OutputVector, TargetEnum .CUDA , true > ( dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput);
12411243 case spirv:
1242- linearTransformBwdOnTarget< Storage, OutputVector, TargetEnum .SPIR_V , true > (dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput);
1244+ linearTransformBwdOnTarget< Storage, Layout, OutputVector, TargetEnum .SPIR_V , true > (dthis, dWeightStorage, dWeightAddress, dBiasStorage, dBiasAddress, doutput);
12431245 }
12441246 }
12451247
12461248 [Differentiable]
1247- public OutputVector linearTransform< Address, OutputVector> (
1249+ public OutputVector linearTransform< Address, Layout, OutputVector> (
12481250 Address weightAddress)
12491251 where Address : IPointerLikeAddress< T>
12501252 where Address .Differential : IPointerLikeAddress< T .Differential >
1253+ where Layout : IStorageLayout
12511254 where OutputVector : IVector< T>
12521255 {
12531256 OutputVector output = OutputVector();
12541257 return output;
12551258 }
12561259
12571260 [Differentiable]
1258- public OutputVector linearTransform< Address, OutputVector> (
1261+ public OutputVector linearTransform< Address, Layout, OutputVector> (
12591262 Address weightAddress, Address biasAddress)
12601263 where Address : IPointerLikeAddress< T>
12611264 where Address .Differential : IPointerLikeAddress< T .Differential >
1265+ where Layout : IStorageLayout
12621266 where OutputVector : IVector< T>
12631267 {
12641268 OutputVector output = OutputVector();
0 commit comments