Skip to content

Commit 8235835

Browse files
r-barnesfacebook-github-bot
authored andcommitted
[codemod] Fix signed shift error in caffe2/torch/fb/retrieval/fused_kmean_ann_jagged_output.cu
Summary: Bit-shifting signed variables is scary because the bit-shift can result in the signed variable becoming negative, which is often unwanted. This diff fixes such a situation, which is often as a simple as changing `1 << 31` to `1ul << 31` or an equivalent construction. - If you approve of this diff, please use the "Accept & Ship" button :-) Reviewed By: wenxin0319 Differential Revision: D72536140 fbshipit-source-id: 887255469368f5b4b500dbb42ff7f3307e6f68bb
1 parent ad47246 commit 8235835

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,11 @@ void deformable_im2col(
241241
// https://github.com/pytorch/vision/issues/4269
242242
bool use_64bits_indexing = false;
243243
// Checks if num_kernels or columns numel larger than 2 ** 31
244-
use_64bits_indexing |= num_kernels > (1 << 31);
244+
use_64bits_indexing |= num_kernels > std::numeric_limits<int32_t>::max();
245245
use_64bits_indexing |=
246246
((int64_t)n_in_channels * weight_h * weight_w * parallel_imgs * out_h *
247247
out_w >
248-
(1 << 31));
248+
std::numeric_limits<int32_t>::max());
249249

250250
if (use_64bits_indexing) {
251251
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -432,7 +432,7 @@ void compute_grad_input(
432432
// https://github.com/pytorch/vision/issues/4269
433433
bool use_64bits_indexing = false;
434434
// Checks if num_kernels or columns numel larger than 2 ** 31
435-
use_64bits_indexing |= num_kernels > (1 << 31);
435+
use_64bits_indexing |= num_kernels > std::numeric_limits<int32_t>::max();
436436

437437
at::globalContext().alertNotDeterministic("compute_grad_input");
438438

@@ -673,10 +673,10 @@ void compute_grad_offset_and_mask(
673673
// https://github.com/pytorch/vision/issues/4269
674674
bool use_64bits_indexing = false;
675675
// Checks if columns numel is larger than 2 ** 31
676-
use_64bits_indexing |= num_kernels > (1 << 31);
676+
use_64bits_indexing |= num_kernels > std::numeric_limits<int32_t>::max();
677677
use_64bits_indexing |=
678678
((int64_t)channels * weight_h * weight_w * parallel_imgs * out_h * out_w >
679-
(1 << 31));
679+
std::numeric_limits<int32_t>::max());
680680

681681
if (use_64bits_indexing) {
682682
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

0 commit comments

Comments
 (0)