11#include  < ATen/native/mps/OperationUtils.h> 
22
3- namespace  vision  {
4- namespace  ops  {
5- 
6- namespace  mps  {
3+ namespace  vision ::ops::mps {
74
85static  at::native::mps::MetalShaderLibrary lib (R"VISION_METAL( 
96
@@ -115,15 +112,15 @@ inline T bilinear_interpolate_deformable_conv2d(
115112  T v1 = 0; 
116113  if (y_low >= 0 && x_low >= 0) 
117114    v1 = input[y_low * width + x_low]; 
118-    
115+ 
119116  T v2 = 0; 
120117  if (y_low >= 0 && x_high <= width - 1) 
121118    v2 = input[y_low * width + x_high]; 
122-    
119+ 
123120  T v3 = 0; 
124121  if (y_high <= height - 1 && x_low >= 0) 
125122    v3 = input[y_high * width + x_low]; 
126-    
123+ 
127124  T v4 = 0; 
128125  if (y_high <= height - 1 && x_high <= width - 1) 
129126    v4 = input[y_high * width + x_high]; 
@@ -228,7 +225,7 @@ kernel void nms(constant  T        * dev_boxes     [[buffer(0)]],
228225                constant  float    & iou_threshold [[buffer(3)]], 
229226                uint2     tgid     [[threadgroup_position_in_grid]], 
230227                uint2     tid2     [[thread_position_in_threadgroup]]) { 
231-    
228+ 
232229  const uint row_start = tgid.y; 
233230  const uint col_start = tgid.x; 
234231  const uint tid = tid2.x; 
@@ -245,7 +242,7 @@ kernel void nms(constant  T        * dev_boxes     [[buffer(0)]],
245242    const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid; 
246243    uint64_t t = 0; 
247244    uint start = 0; 
248-      
245+ 
249246    if (row_start == col_start) { 
250247      start = tid + 1; 
251248    } 
@@ -309,48 +306,48 @@ kernel void deformable_im2col_kernel(
309306    int out_b = (tid / (out_w * out_h)) % batch_size; 
310307    int in_c  = tid / (out_w * out_h * batch_size); 
311308    int out_c = in_c * weight_h * weight_w; 
312-      
309+ 
313310    int c_per_offset_grp = n_in_channels / n_offset_grps; 
314311    int grp_idx = in_c / c_per_offset_grp; 
315-      
312+ 
316313    int col_offset = out_c * (batch_size * out_h * out_w) 
317314                      + out_b * (out_h * out_w) 
318315                      + out_y * out_w + out_x; 
319316    device T* local_columns_ptr = columns_ptr + col_offset; 
320-      
317+ 
321318    int input_offset = out_b * (n_in_channels * height * width) 
322319                        + in_c * (height * width); 
323320    constant T* local_input_ptr = input_ptr + input_offset; 
324-      
321+ 
325322    int offset_offset = (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; 
326323    constant T* local_offset_ptr = offset_ptr + offset_offset; 
327-      
324+ 
328325    constant T* local_mask_ptr = nullptr; 
329326    if (use_mask) { 
330327        int mask_offset = (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w; 
331328        local_mask_ptr = mask_ptr + mask_offset; 
332329    } 
333-      
330+ 
334331    for (int i = 0; i < weight_h; ++i) { 
335332        for (int j = 0; j < weight_w; ++j) { 
336333            int mask_index = i * weight_w + j; 
337334            int offset_index = 2 * mask_index; 
338-              
335+ 
339336            T mask_value = 1; 
340337            if (use_mask) { 
341338                mask_value = local_mask_ptr[mask_index * (out_h * out_w) + out_y * out_w + out_x]; 
342339            } 
343-              
340+ 
344341            T offset_h_val = local_offset_ptr[offset_index * (out_h * out_w) + out_y * out_w + out_x]; 
345342            T offset_w_val = local_offset_ptr[(offset_index + 1) * (out_h * out_w) + out_y * out_w + out_x]; 
346-              
343+ 
347344            T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h_val; 
348345            T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w_val; 
349-              
346+ 
350347            T interp = bilinear_interpolate_deformable_conv2d(local_input_ptr, height, width, y, x, tid); 
351-              
348+ 
352349            *local_columns_ptr = mask_value * interp; 
353-              
350+ 
354351            local_columns_ptr += batch_size * out_h * out_w; 
355352        } 
356353    } 
@@ -584,7 +581,7 @@ kernel void roi_align_backward(
584581          atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast<T>(g2)); 
585582          atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast<T>(g3)); 
586583          atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast<T>(g4)); 
587-            
584+ 
588585        } // if 
589586      } // ix 
590587    } // iy 
@@ -742,7 +739,6 @@ kernel void roi_pool_backward(
742739    if (argmax != -1) { 
743740      atomic_add_float(grad_input + offset + argmax, static_cast<T>(grad_output[output_offset + ph * h_stride + pw * w_stride])); 
744741    } 
745-      
746742  } // MPS_1D_KERNEL_LOOP 
747743} 
748744
@@ -1139,7 +1135,6 @@ kernel void ps_roi_pool_backward(
11391135        atomic_add_float(grad_input + offset + grad_input_index, diff_val); 
11401136      } 
11411137    } 
1142-      
11431138  } // MPS_1D_KERNEL_LOOP 
11441139} 
11451140
@@ -1157,7 +1152,7 @@ kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>(          \
11571152    constant int64_t & width           [[buffer(7)]],        \ 
11581153    constant int64_t & pooled_height   [[buffer(8)]],        \ 
11591154    constant int64_t & pooled_width    [[buffer(9)]],        \ 
1160-     constant int64_t & channels_out    [[buffer(10)]],       \   
1155+     constant int64_t & channels_out    [[buffer(10)]],       \ 
11611156    constant float   & spatial_scale   [[buffer(11)]],       \ 
11621157    uint2     tgid   [[threadgroup_position_in_grid]],       \ 
11631158    uint2     tptg   [[threads_per_threadgroup]],            \ 
@@ -1192,6 +1187,4 @@ static id<MTLComputePipelineState> visionPipelineState(
11921187  return  lib.getPipelineStateForFunc (kernel);
11931188}
11941189
1195- } //  namespace mps
1196- } //  namespace ops
1197- } //  namespace vision
1190+ } //  namespace vision::ops::mps
0 commit comments