Skip to content

Commit ffb7a08

Browse files
zeshengzongpytorchmergebot
authored andcommitted
1 parent 356fc41 commit ffb7a08

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

aten/src/ATen/native/cuda/SummaryOps.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,10 @@ Tensor _histc_cuda_template(
320320
std::nullopt /* layout */,
321321
DeviceType::CUDA,
322322
std::nullopt /* pin_memory */);
323-
input_t minvalue = min;
324-
input_t maxvalue = max;
323+
using bounds_t = at::acc_type<input_t, /*is_cuda=*/true>;
324+
bounds_t minvalue = min;
325+
bounds_t maxvalue = max;
326+
325327
if (min == max && self.numel() > 0) {
326328
minvalue = *self.min().cpu().const_data_ptr<input_t>();
327329
maxvalue = *self.max().cpu().const_data_ptr<input_t>();

test/test_reductions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3116,6 +3116,30 @@ def test_histc_lowp(self, device, dtype):
31163116
actual)
31173117
self.assertEqual(actual.dtype, dtype)
31183118

3119+
@dtypes(torch.uint8, torch.int8, torch.int, torch.long, torch.float, torch.double)
3120+
def test_histc_min_max_errors(self, device, dtype):
3121+
with self.assertRaisesRegex(RuntimeError, "max must be larger than min"):
3122+
torch.histc(torch.tensor([1., 2., 3.], dtype=dtype, device=device), bins=4, min=5, max=1)
3123+
3124+
@dtypes(torch.float, torch.double)
3125+
def test_histc_min_max_corner_cases(self, device, dtype):
3126+
actual = torch.histc(
3127+
torch.tensor([1., 2, 1], dtype=dtype, device=device),
3128+
bins=4, min=5, max=5)
3129+
self.assertEqual(
3130+
torch.tensor([2, 0, 0, 1], dtype=dtype, device=device),
3131+
actual)
3132+
3133+
@onlyCUDA
3134+
@dtypes(torch.uint8, torch.int8, torch.int, torch.long)
3135+
def test_histc_min_max_corner_cases_cuda(self, device, dtype):
3136+
actual = torch.histc(
3137+
torch.tensor([1., 2, 1], dtype=dtype, device=device),
3138+
bins=4, min=5, max=5)
3139+
self.assertEqual(
3140+
torch.tensor([2, 0, 0, 1], dtype=dtype, device=device),
3141+
actual)
3142+
31193143
"""
31203144
Runs torch.histogram and numpy.histogram on the specified input parameters
31213145
and asserts that their output is equal.

torch/testing/_internal/common_methods_invocations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19375,7 +19375,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1937519375
)),
1937619376
OpInfo('histc',
1937719377
dtypes=floating_types_and(torch.bfloat16, torch.float16),
19378-
dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64),
19378+
dtypesIfCUDA=floating_types_and(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64),
1937919379
sample_inputs_func=sample_inputs_histc,
1938019380
supports_out=True,
1938119381
supports_autograd=False,

0 commit comments

Comments
 (0)