@@ -5,7 +5,7 @@ namespace ops {
55
66namespace  mps  {
77
8- static  const   char * METAL_VISION =  R"VISION_METAL( 
8+ static  at::native::mps::MetalShaderLibrary  lib ( R"VISION_METAL( 
99
1010#include <metal_atomic> 
1111#include <metal_stdlib> 
@@ -26,46 +26,15 @@ inline T ceil_div(T n, T m) {
2626  return (n + m - 1) / m; 
2727} 
2828
// atomic_add_float: adds `val` to the value stored at `data_ptr` using the
// atomic functions from <metal_atomic>. In this hunk the templated version
// (lines marked `-`) is dropped and two non-template overloads, one for float
// and one for half, are introduced (lines marked `+`).
29- template <typename T> 
30- inline void atomic_add_float( device T* data_ptr, const T val) 
// float overload: its body is the single relaxed atomic_fetch_add_explicit
// further down, applied to `data_ptr` viewed as device atomic_float*.
// NOTE(review): the removed version only took the atomic_float path when
// __METAL_VERSION__ >= 300 and otherwise used the exchange-based fallback;
// here the atomic_float path is unconditional - confirm the shader is always
// compiled with a language version that provides atomic_float.
29+ inline void atomic_add_float(device float* data_ptr, const float val) 
3130{ 
32- #if __METAL_VERSION__ >= 300 
33-   // atomic_float is supported in Metal 3 (macOS Ventura) onward. 
34-   device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); 
35- #else 
36-   // Custom atomic addition implementation 
37-   // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 
38-   // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 
39-   // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) 
40-    
41-   // Create an atomic uint pointer for atomic transaction. 
42-   device atomic_uint* atom_var = (device atomic_uint*)data_ptr; 
43-   // Create necessary storage. 
44-   uint  fetched_uint,  assigning_uint; 
45-   T fetched_float, assigning_float; 
46- 
47-   // Replace the value in atom_var with 0 and return the previous value in atom_var. 
48-   fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); 
49-   // Read out the previous value as float. 
50-   fetched_float = *( (thread T*) &fetched_uint ); 
51- 
52-   // Do addition and represent the addition result in uint for atomic transaction. 
53-   assigning_float = fetched_float + val; 
54-   assigning_uint =  *((thread uint*) &assigning_float); 
55- 
56-   // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). 
57-   while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0)  { 
58-     // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. 
59-     // Try to assign 0 and get the previously assigned addition result. 
60-     uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); 
61-     T fetched_float_again = *( (thread T*) &fetched_uint_again ); 
62-     // Re-add again 
63-     fetched_float = *((thread T*) &(fetched_uint)); 
64-     // Previously assigned addition result + addition result from other threads. 
65-     assigning_float = fetched_float_again + fetched_float; 
66-     assigning_uint =  *( (thread uint*) &assigning_float); 
67-   } 
68- #endif 
31+   atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); 
32+ } 
33+ 
34+ 
// half overload: converts `val` to float with static_cast and issues the same
// relaxed atomic_float add on `data_ptr` reinterpreted as device atomic_float*.
// NOTE(review): the destination is declared as device half*, yet the atomic
// is performed on an atomic_float at that address. Presumably this reads and
// writes a float-sized object where only a half is stored, which would also
// touch the neighbouring element - confirm, and check whether half
// destinations need a different update scheme (e.g. compare-exchange on the
// containing word).
35+ inline void atomic_add_float(device half* data_ptr, const half val) 
36+ { 
37+   atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast<float>(val), memory_order_relaxed); 
6938} 
7039
7140template <typename T, typename integer_t> 
@@ -1061,40 +1030,12 @@ REGISTER_PS_ROI_POOL_OP(half, int64_t);
10611030REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); 
10621031REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); 
10631032
1064- )VISION_METAL" 
1065- 
// compileVisionOpsLibrary (removed by this change, every line is marked `-`):
// built an MTLLibrary from the METAL_VISION source string on `device` with the
// language version set to MTLLanguageVersion2_3, kept the result in a
// function-local static, and returned that cached library on later calls.
1066- static  id<MTLLibrary> compileVisionOpsLibrary (id<MTLDevice> device) {
1067-   static  id<MTLLibrary> visionLibrary = nil;
1068-   if  (visionLibrary) {
1069-     return  visionLibrary;
1070-   }
1071- 
1072-   NSError* error = nil;
1073-   MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
1074-   [options setLanguageVersion:MTLLanguageVersion2_3];
1075-   visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
1076-                                        options:options
1077-                                          error:&error];
// NOTE(review): the TORCH_CHECK line below appears cut off after its message
// string; the remaining arguments and the closing `);` are not visible here.
1078-   TORCH_CHECK (visionLibrary, " Failed to create metal vision library, error: " 
1079-   return  visionLibrary;
1080- }
1081- 
// visionPipelineState: returns the compute pipeline state for the kernel named
// `kernel`.
// Removed body (`-` lines): looked the name up in a function-local static map;
// on a miss it obtained the library from compileVisionOpsLibrary(device),
// created the MTLFunction and the pipeline state, stored the pipeline state in
// the map and returned it.
// Added body (`+` lines): forwards to lib.getPipelineStateForFunc(kernel).
1082- static  id<MTLComputePipelineState> visionPipelineState (id<MTLDevice> device, const  std::string& kernel) {
1083-   static  std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
1084-   id<MTLComputePipelineState> pso = psoCache[kernel];
1085-   if  (pso) {
1086-     return  pso;
1087-   }
1088- 
1089-   NSError* error = nil;
1090-   id<MTLLibrary> visionLib = compileVisionOpsLibrary (device);
1091-   id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str ()]];
// NOTE(review): both TORCH_CHECK lines in this removed body appear cut off
// after their message strings; their remaining arguments are not visible here.
1092-   TORCH_CHECK (visionFunc, " Failed to create function state object for: " 
1093-   pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
1094-   TORCH_CHECK (pso, " Failed to created pipeline state object, error: " 
1033+ )VISION_METAL" 
10951034
// NOTE(review): the added `)VISION_METAL"` line above ends the raw string that
// is passed to the `lib (` initializer, but the closing of that initializer is
// not visible on it - the line looks truncated.
1096-   psoCache[kernel] = pso;
1097-   return  pso;
// NOTE(review): `device` is not referenced by the added body; it remains in
// the signature.
1035+ static  id<MTLComputePipelineState> visionPipelineState (
1036+     id<MTLDevice> device,
1037+     const  std::string& kernel) {
1038+   return  lib.getPipelineStateForFunc (kernel);
10981039}
10991040
11001041} //  namespace mps
0 commit comments