Skip to content

Commit 7313516

Browse files
committed
Resolve PR comments and collate paired (h, w) int kernel arguments into int2
1 parent ee6104d commit 7313516

File tree

2 files changed

+68
-58
lines changed

2 files changed

+68
-58
lines changed

torchvision/csrc/ops/mps/deform_conv2d_kernel.mm

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,30 @@
3535
TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D");
3636
TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true");
3737
TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor");
38+
TORCH_CHECK(weight.is_mps(), "weight must be a MPS tensor");
39+
TORCH_CHECK(offset.is_mps(), "offset must be a MPS tensor");
40+
TORCH_CHECK(mask.is_mps(), "mask must be a MPS tensor");
41+
TORCH_CHECK(bias.is_mps(), "bias must be a MPS tensor");
3842

3943
at::DeviceGuard guard(input_c.device());
4044

41-
int batch = input_c.size(0);
42-
int in_channels = input_c.size(1);
43-
int in_h = input_c.size(2);
44-
int in_w = input_c.size(3);
45-
int weight_h = weight_c.size(2);
46-
int weight_w = weight_c.size(3);
47-
int out_channels = weight_c.size(0);
48-
int ker_h = dilation_h * (weight_h - 1) + 1;
49-
int ker_w = dilation_w * (weight_w - 1) + 1;
50-
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
51-
int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
45+
uint32_t batch = input_c.size(0);
46+
uint32_t in_channels = input_c.size(1);
47+
uint32_t in_h = input_c.size(2);
48+
uint32_t in_w = input_c.size(3);
49+
uint32_t weight_h = weight_c.size(2);
50+
uint32_t weight_w = weight_c.size(3);
51+
uint32_t out_channels = weight_c.size(0);
52+
uint32_t ker_h = dilation_h * (weight_h - 1) + 1;
53+
uint32_t ker_w = dilation_w * (weight_w - 1) + 1;
54+
uint32_t out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
55+
uint32_t out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
56+
uint32_t pad_h_u = static_cast<uint32_t>(pad_h);
57+
uint32_t pad_w_u = static_cast<uint32_t>(pad_w);
58+
uint32_t stride_h_u = static_cast<uint32_t>(stride_h);
59+
uint32_t stride_w_u = static_cast<uint32_t>(stride_w);
60+
uint32_t dilation_h_u = static_cast<uint32_t>(dilation_h);
61+
uint32_t dilation_w_u = static_cast<uint32_t>(dilation_w);
5262

5363
TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels,
5464
"Input channels (", in_channels,
@@ -103,8 +113,13 @@
103113
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
104114
[computeEncoder setComputePipelineState:pipelineState];
105115
at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer,
106-
in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w,
107-
dilation_h, dilation_w, batch, in_channels, n_offset_grps, out_h, out_w,
116+
std::array<uint32_t, 2>{in_h, in_w},
117+
std::array<uint32_t, 2>{weight_h, weight_w},
118+
std::array<uint32_t, 2>{pad_h_u, pad_w_u},
119+
std::array<uint32_t, 2>{stride_h_u, stride_w_u},
120+
std::array<uint32_t, 2>{dilation_h_u, dilation_w_u},
121+
batch, in_channels, n_offset_grps,
122+
std::array<uint32_t, 2>{out_h, out_w},
108123
use_mask, outputBuffer);
109124
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
110125
}

torchvision/csrc/ops/mps/mps_kernels.h

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -277,28 +277,29 @@ kernel void deformable_im2col_kernel(
277277
constant T* input_ptr [[ buffer(0) ]],
278278
constant T* offset_ptr [[ buffer(1) ]],
279279
constant T* mask_ptr [[ buffer(2) ]],
280-
constant int& height [[ buffer(3) ]],
281-
constant int& width [[ buffer(4) ]],
282-
constant int& weight_h [[ buffer(5) ]],
283-
constant int& weight_w [[ buffer(6) ]],
284-
constant int& pad_h [[ buffer(7) ]],
285-
constant int& pad_w [[ buffer(8) ]],
286-
constant int& stride_h [[ buffer(9) ]],
287-
constant int& stride_w [[ buffer(10)]],
288-
constant int& dilation_h [[ buffer(11)]],
289-
constant int& dilation_w [[ buffer(12)]],
290-
constant int& batch_size [[ buffer(13)]],
291-
constant int& n_in_channels [[ buffer(14)]],
292-
constant int& n_offset_grps [[ buffer(15)]],
293-
constant int& out_h [[ buffer(16)]],
294-
constant int& out_w [[ buffer(17)]],
295-
constant bool& use_mask [[ buffer(18)]],
296-
device T* columns_ptr [[ buffer(19)]],
280+
constant int2& input_size [[ buffer(3) ]], // (height, width)
281+
constant int2& weight_size [[ buffer(4) ]], // (weight_h, weight_w)
282+
constant int2& pad [[ buffer(5) ]], // (pad_h, pad_w)
283+
constant int2& stride [[ buffer(6) ]], // (stride_h, stride_w)
284+
constant int2& dilation [[ buffer(7) ]], // (dilation_h, dilation_w)
285+
constant int& batch_size [[ buffer(8) ]],
286+
constant int& n_in_channels [[ buffer(9) ]],
287+
constant int& n_offset_grps [[ buffer(10)]],
288+
constant int2& out_size [[ buffer(11)]], // (out_h, out_w)
289+
constant bool& use_mask [[ buffer(12)]],
290+
device T* columns_ptr [[ buffer(13)]],
297291
uint tid [[ thread_position_in_grid ]],
298-
uint tpg [[ threads_per_grid ]])
292+
uint tpg [[ threads_per_grid ]]
293+
)
299294
{
295+
int height = input_size.x, width = input_size.y;
296+
int weight_h = weight_size.x, weight_w = weight_size.y;
297+
int pad_h = pad.x, pad_w = pad.y;
298+
int stride_h = stride.x, stride_w = stride.y;
299+
int dilation_h = dilation.x, dilation_w = dilation.y;
300+
int out_h = out_size.x, out_w = out_size.y;
301+
300302
int total = out_w * out_h * batch_size * n_in_channels;
301-
int gridSize = tpg;
302303
if (tid >= total) {
303304
return;
304305
}
@@ -355,32 +356,26 @@ kernel void deformable_im2col_kernel(
355356
}
356357
}
357358
358-
#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \
359-
template \
360-
[[host_name("deformable_im2col_" #DTYPE)]] \
361-
kernel void deformable_im2col_kernel<DTYPE>( \
362-
constant DTYPE* input_ptr [[ buffer(0) ]], \
363-
constant DTYPE* offset_ptr [[ buffer(1) ]], \
364-
constant DTYPE* mask_ptr [[ buffer(2) ]], \
365-
constant int& height [[ buffer(3) ]], \
366-
constant int& width [[ buffer(4) ]], \
367-
constant int& weight_h [[ buffer(5) ]], \
368-
constant int& weight_w [[ buffer(6) ]], \
369-
constant int& pad_h [[ buffer(7) ]], \
370-
constant int& pad_w [[ buffer(8) ]], \
371-
constant int& stride_h [[ buffer(9) ]], \
372-
constant int& stride_w [[ buffer(10)]], \
373-
constant int& dilation_h [[ buffer(11)]], \
374-
constant int& dilation_w [[ buffer(12)]], \
375-
constant int& batch_sz [[ buffer(13)]], \
376-
constant int& n_in_channels[[ buffer(14)]], \
377-
constant int& n_offset_grps[[ buffer(15)]], \
378-
constant int& out_h [[ buffer(16)]], \
379-
constant int& out_w [[ buffer(17)]], \
380-
constant bool& use_mask [[ buffer(18)]], \
381-
device DTYPE* columns_ptr [[ buffer(19)]], \
382-
uint tid [[ thread_position_in_grid ]], \
383-
uint tpg [[ threads_per_grid ]]);
359+
#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \
360+
template \
361+
[[host_name("deformable_im2col_" #DTYPE)]] \
362+
kernel void deformable_im2col_kernel<DTYPE>( \
363+
constant DTYPE* input_ptr [[ buffer(0) ]], \
364+
constant DTYPE* offset_ptr [[ buffer(1) ]], \
365+
constant DTYPE* mask_ptr [[ buffer(2) ]], \
366+
constant int2& input_size [[ buffer(3) ]], /* (h, w) */ \
367+
constant int2& weight_size [[ buffer(4) ]], /* (h, w) */ \
368+
constant int2& pad [[ buffer(5) ]], /* (h, w) */ \
369+
constant int2& stride [[ buffer(6) ]], /* (h, w) */ \
370+
constant int2& dilation [[ buffer(7) ]], /* (h, w) */ \
371+
constant int& batch_size [[ buffer(8) ]], \
372+
constant int& n_in_channels [[ buffer(9) ]], \
373+
constant int& n_offset_grps [[ buffer(10)]], \
374+
constant int2& out_size [[ buffer(11)]], /* (h, w) */ \
375+
constant bool& use_mask [[ buffer(12)]], \
376+
device DTYPE* columns_ptr [[ buffer(13)]], \
377+
uint tid [[ thread_position_in_grid ]], \
378+
uint tpg [[ threads_per_grid ]]);
384379
385380
template<typename T, typename integer_t>
386381
kernel void roi_align(

0 commit comments

Comments
 (0)