@@ -277,28 +277,29 @@ kernel void deformable_im2col_kernel(
277277 constant T* input_ptr [[ buffer(0) ]],
278278 constant T* offset_ptr [[ buffer(1) ]],
279279 constant T* mask_ptr [[ buffer(2) ]],
280- constant int& height [[ buffer(3) ]],
281- constant int& width [[ buffer(4) ]],
282- constant int& weight_h [[ buffer(5) ]],
283- constant int& weight_w [[ buffer(6) ]],
284- constant int& pad_h [[ buffer(7) ]],
285- constant int& pad_w [[ buffer(8) ]],
286- constant int& stride_h [[ buffer(9) ]],
287- constant int& stride_w [[ buffer(10)]],
288- constant int& dilation_h [[ buffer(11)]],
289- constant int& dilation_w [[ buffer(12)]],
290- constant int& batch_size [[ buffer(13)]],
291- constant int& n_in_channels [[ buffer(14)]],
292- constant int& n_offset_grps [[ buffer(15)]],
293- constant int& out_h [[ buffer(16)]],
294- constant int& out_w [[ buffer(17)]],
295- constant bool& use_mask [[ buffer(18)]],
296- device T* columns_ptr [[ buffer(19)]],
280+ constant int2& input_size [[ buffer(3) ]], // (height, width)
281+ constant int2& weight_size [[ buffer(4) ]], // (weight_h, weight_w)
282+ constant int2& pad [[ buffer(5) ]], // (pad_h, pad_w)
283+ constant int2& stride [[ buffer(6) ]], // (stride_h, stride_w)
284+ constant int2& dilation [[ buffer(7) ]], // (dilation_h, dilation_w)
285+ constant int& batch_size [[ buffer(8) ]],
286+ constant int& n_in_channels [[ buffer(9) ]],
287+ constant int& n_offset_grps [[ buffer(10)]],
288+ constant int2& out_size [[ buffer(11)]], // (out_h, out_w)
289+ constant bool& use_mask [[ buffer(12)]],
290+ device T* columns_ptr [[ buffer(13)]],
297291 uint tid [[ thread_position_in_grid ]],
298- uint tpg [[ threads_per_grid ]])
292+ uint tpg [[ threads_per_grid ]]
293+ )
299294{
295+ int height = input_size.x, width = input_size.y;
296+ int weight_h = weight_size.x, weight_w = weight_size.y;
297+ int pad_h = pad.x, pad_w = pad.y;
298+ int stride_h = stride.x, stride_w = stride.y;
299+ int dilation_h = dilation.x, dilation_w = dilation.y;
300+ int out_h = out_size.x, out_w = out_size.y;
301+
300302 int total = out_w * out_h * batch_size * n_in_channels;
301- int gridSize = tpg;
302303 if (tid >= total) {
303304 return;
304305 }
@@ -355,32 +356,26 @@ kernel void deformable_im2col_kernel(
355356 }
356357}
357358
358- #define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \
359- template \
360- [[host_name("deformable_im2col_" #DTYPE)]] \
361- kernel void deformable_im2col_kernel<DTYPE>( \
362- constant DTYPE* input_ptr [[ buffer(0) ]], \
363- constant DTYPE* offset_ptr [[ buffer(1) ]], \
364- constant DTYPE* mask_ptr [[ buffer(2) ]], \
365- constant int& height [[ buffer(3) ]], \
366- constant int& width [[ buffer(4) ]], \
367- constant int& weight_h [[ buffer(5) ]], \
368- constant int& weight_w [[ buffer(6) ]], \
369- constant int& pad_h [[ buffer(7) ]], \
370- constant int& pad_w [[ buffer(8) ]], \
371- constant int& stride_h [[ buffer(9) ]], \
372- constant int& stride_w [[ buffer(10)]], \
373- constant int& dilation_h [[ buffer(11)]], \
374- constant int& dilation_w [[ buffer(12)]], \
375- constant int& batch_sz [[ buffer(13)]], \
376- constant int& n_in_channels[[ buffer(14)]], \
377- constant int& n_offset_grps[[ buffer(15)]], \
378- constant int& out_h [[ buffer(16)]], \
379- constant int& out_w [[ buffer(17)]], \
380- constant bool& use_mask [[ buffer(18)]], \
381- device DTYPE* columns_ptr [[ buffer(19)]], \
382- uint tid [[ thread_position_in_grid ]], \
383- uint tpg [[ threads_per_grid ]]);
359+ #define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \
360+ template \
361+ [[host_name("deformable_im2col_" #DTYPE)]] \
362+ kernel void deformable_im2col_kernel<DTYPE>( \
363+ constant DTYPE* input_ptr [[ buffer(0) ]], \
364+ constant DTYPE* offset_ptr [[ buffer(1) ]], \
365+ constant DTYPE* mask_ptr [[ buffer(2) ]], \
366+ constant int2& input_size [[ buffer(3) ]], /* (h, w) */ \
367+ constant int2& weight_size [[ buffer(4) ]], /* (h, w) */ \
368+ constant int2& pad [[ buffer(5) ]], /* (h, w) */ \
369+ constant int2& stride [[ buffer(6) ]], /* (h, w) */ \
370+ constant int2& dilation [[ buffer(7) ]], /* (h, w) */ \
371+ constant int& batch_size [[ buffer(8) ]], \
372+ constant int& n_in_channels [[ buffer(9) ]], \
373+ constant int& n_offset_grps [[ buffer(10)]], \
374+ constant int2& out_size [[ buffer(11)]], /* (h, w) */ \
375+ constant bool& use_mask [[ buffer(12)]], \
376+ device DTYPE* columns_ptr [[ buffer(13)]], \
377+ uint tid [[ thread_position_in_grid ]], \
378+ uint tpg [[ threads_per_grid ]]);
384379
385380template<typename T, typename integer_t>
386381kernel void roi_align(
0 commit comments