
Commit 88dd792: fix zero element tensors
Parent: bf1f101

9 files changed: 132 additions, 51 deletions

csrc/cpu/scatter_cpu.cpp

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,8 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
     auto sizes = src.sizes().vec();
     if (dim_size.has_value())
       sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
     else
       sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
     out = torch::empty(sizes, src.options());
@@ -41,6 +43,9 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
     arg_out_data = arg_out.value().data_ptr<int64_t>();
   }
 
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto B = 1;
   for (auto i = 0; i < dim; i++)
     B *= src.size(i);
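
The guard above mirrors the Python-side fix: when `index` is empty and no `dim_size` is given, the reduced dimension collapses to 0 instead of calling `index.max()` on an empty tensor, and the kernel returns before any reduction loop. A minimal usage sketch, assuming a build that includes this commit (shapes chosen to resemble the new test):

```python
import torch
from torch_scatter import scatter

# Zero-element input and no dim_size: the kernel sets the reduced dimension
# to 0 and returns early instead of crashing on index.max().
src = torch.randn(0, 16)
index = torch.empty(0, dtype=torch.long)

out = scatter(src, index, dim=0, reduce='max')
print(out.shape)  # torch.Size([0, 16])
```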

csrc/cpu/segment_coo_cpu.cpp

Lines changed: 12 additions & 7 deletions
@@ -34,6 +34,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
     sizes = src.sizes().vec();
     if (dim_size.has_value())
       sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
     else
       sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
     out = torch::empty(sizes, src.options());
@@ -44,15 +46,15 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
     arg_out = torch::full_like(out, src.size(dim), index.options());
     arg_out_data = arg_out.value().data_ptr<int64_t>();
-  }
-
-  torch::optional<torch::Tensor> count = torch::nullopt;
-  if (reduce2REDUCE.at(reduce) == MEAN) {
+  } else if (reduce2REDUCE.at(reduce) == MEAN) {
     auto sizes = index.sizes().vec();
     sizes[dim] = out.size(dim);
-    count = torch::zeros(sizes, out.options());
+    arg_out = torch::zeros(sizes, out.options());
   }
 
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto B = index.numel() / src.size(dim);
   auto E = src.size(dim);
   auto K = src.numel() / index.numel();
@@ -72,7 +74,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
     if (!optional_out.has_value())
       out.fill_(Reducer<scalar_t, REDUCE>::init());
     if (REDUCE == MEAN)
-      count_data = count.value().data_ptr<scalar_t>();
+      count_data = arg_out.value().data_ptr<scalar_t>();
 
     for (auto b = 0; b < B; b++) {
       auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
@@ -122,7 +124,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
       out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
 
     if (REDUCE == MEAN)
-      arg_out = count;
+      arg_out.value().clamp_(1);
   });
 });
 
@@ -156,6 +158,9 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
     out = torch::empty(sizes, src.options());
   }
 
+  if (index.numel() == 0)
+    return out;
+
   auto B = index.numel() / out.size(dim);
   auto E = index.size(dim);
   auto K = out.numel() / index.numel();
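
Two things change here: the MEAN path now stores its per-segment counts in `arg_out` rather than a separate `count` tensor (clamped to 1 so empty segments do not divide by zero), and an empty `index` short-circuits before `B`, `E`, and `K` are derived from `src.size(dim)`. A hedged usage sketch, assuming the package is built with this change (shapes taken from the new test):

```python
import torch
from torch_scatter import segment_coo

# Empty COO index: the kernel returns a zero-sized result before computing
# B = index.numel() / src.size(dim), which would divide by zero here.
src = torch.randn(0, 0, 0, 16)
index = torch.empty(0, dtype=torch.long)

out = segment_coo(src, index, dim_size=0, reduce='mean')
print(out.shape)  # torch.Size([0, 0, 0, 16])
```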

csrc/cpu/segment_csr_cpu.cpp

Lines changed: 13 additions & 4 deletions
@@ -30,10 +30,10 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
     for (auto i = 0; i < out.dim(); i++)
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
-    CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
+    CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
   } else {
     sizes = src.sizes().vec();
-    sizes[dim] = indptr.size(dim) - 1;
+    sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
     out = torch::empty(sizes, src.options());
   }
 
@@ -44,6 +44,9 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
     arg_out_data = arg_out.value().data_ptr<int64_t>();
   }
 
+  if (src.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(dim);
@@ -98,7 +101,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
   indptr = indptr.expand(sizes);
 
   auto dim = indptr.dim() - 1;
-  CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
+  CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
 
   src = src.contiguous();
 
@@ -110,10 +113,16 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
         CHECK_INPUT(src.size(i) == out.size(i));
   } else {
     auto sizes = src.sizes().vec();
-    sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
+    if (src.numel() > 0)
+      sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
+    else
+      sizes[dim] = 0;
     out = torch::empty(sizes, src.options());
   }
 
+  if (src.numel() == 0)
+    return out;
+
   auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
   auto K = src.numel() / N;
   auto E = out.size(dim);
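
With these guards, an empty `src` (or an empty `indptr`) is handled before `N`, `K`, and `E` are computed, and the output dimension is clamped to `max(indptr.size(dim) - 1, 0)` so a degenerate pointer tensor cannot produce a negative size. A small sketch of the resulting behaviour, assuming a build with this commit and the shapes used in the new test:

```python
import torch
from torch_scatter import segment_csr, gather_csr

# Empty src and empty indptr: both ops return zero-sized tensors and skip the
# reduction/gather loops entirely.
src = torch.randn(0, 0, 0, 16)
indptr = torch.empty(0, dtype=torch.long)

print(segment_csr(src, indptr, reduce='sum').shape)  # torch.Size([0, 0, 0, 16])
print(gather_csr(src, indptr).shape)                 # torch.Size([0, 0, 0, 16])
```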

csrc/cuda/scatter_cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -81,6 +81,8 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
     auto sizes = src.sizes().vec();
     if (dim_size.has_value())
       sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
     else {
       auto d_size = index.max().data_ptr<int64_t>();
       auto h_size = (int64_t *)malloc(sizeof(int64_t));
@@ -97,6 +99,9 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
     arg_out_data = arg_out.value().data_ptr<int64_t>();
   }
 
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto B = 1;
   for (auto i = 0; i < dim; i++)
     B *= src.size(i);

csrc/cuda/segment_coo_cuda.cu

Lines changed: 16 additions & 6 deletions
@@ -181,6 +181,8 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
     sizes = src.sizes().vec();
     if (dim_size.has_value())
       sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
     else {
       auto d_size = index.max().data_ptr<int64_t>();
       auto h_size = (int64_t *)malloc(sizeof(int64_t));
@@ -195,8 +197,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
     arg_out = torch::full_like(out, src.size(dim), index.options());
     arg_out_data = arg_out.value().data_ptr<int64_t>();
+  } else if (reduce2REDUCE.at(reduce) == MEAN) {
+    auto sizes = index.sizes().vec();
+    sizes[dim] = out.size(dim);
+    arg_out = torch::zeros(sizes, out.options());
   }
 
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto E = index.numel();
   auto E_2 = index.size(dim);
   auto E_1 = index.numel() / E_2;
@@ -254,17 +263,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
     }
 
     if (REDUCE == MEAN) {
-      auto sizes = index.sizes().vec();
-      sizes[dim] = out.size(dim);
-      auto count = torch::zeros(sizes, out.options());
-      auto count_data = count.data_ptr<scalar_t>();
+      auto count_data = arg_out.value().data_ptr<scalar_t>();
       segment_coo_kernel<scalar_t, SUM, false>
           <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                  count_data, E, N);
-      arg_out = count;
+      arg_out.value().clamp_(1);
+      auto count = arg_out.value();
       for (int i = dim + 1; i < out.dim(); i++)
         count = count.unsqueeze(-1);
-      out.div_(count.clamp_(1));
+      out.div_(count);
     }
   });
 });
@@ -346,6 +353,9 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
     out = torch::empty(sizes, src.options());
   }
 
+  if (index.numel() == 0)
+    return out;
+
   auto E = index.numel();
   auto K = out.numel() / E;
   auto N = src.size(dim);
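
For reference, the arithmetic the MEAN branch implements (sum per segment, then divide by counts clamped to at least 1 so empty segments stay zero) can be reproduced in plain PyTorch. This is an illustration of the math only, not the CUDA kernel itself:

```python
import torch

index = torch.tensor([0, 0, 2])                  # segment 1 receives nothing
src = torch.tensor([1.0, 3.0, 5.0])

sums = torch.zeros(3).scatter_add_(0, index, src)                    # [4., 0., 5.]
count = torch.zeros(3).scatter_add_(0, index, torch.ones_like(src))  # [2., 0., 1.]
mean = sums / count.clamp(min=1)                                     # [2., 0., 5.]
print(mean)
```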

csrc/cuda/segment_csr_cuda.cu

Lines changed: 16 additions & 7 deletions
@@ -121,10 +121,10 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
     for (int i = 0; i < out.dim(); i++)
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
-    CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
+    CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
   } else {
     sizes = src.sizes().vec();
-    sizes[dim] = indptr.size(dim) - 1;
+    sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
     out = torch::empty(sizes, src.options());
   }
 
@@ -135,6 +135,9 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
     arg_out_data = arg_out.value().data_ptr<int64_t>();
   }
 
+  if (src.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(dim);
@@ -226,7 +229,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
   indptr = indptr.expand(sizes);
 
   auto dim = indptr.dim() - 1;
-  CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
+  CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
 
   src = src.contiguous();
 
@@ -237,14 +240,20 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
   } else {
-    auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
-    auto h_size = (int64_t *)malloc(sizeof(int64_t));
-    cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
     auto sizes = src.sizes().vec();
-    sizes[dim] = *h_size;
+    if (src.numel() > 0) {
+      auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
+      auto h_size = (int64_t *)malloc(sizeof(int64_t));
+      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
+      sizes[dim] = *h_size;
+    } else
+      sizes[dim] = 0;
     out = torch::empty(sizes, src.options());
   }
 
+  if (src.numel() == 0)
+    return out;
+
   auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
   auto K = src.numel() / N;
   auto E = out.size(dim);
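
On the CUDA side, the same guard also avoids the device-to-host `cudaMemcpy` of `indptr`'s last entry that was previously used to size the output. A usage sketch, assuming a CUDA build with this commit and an available GPU:

```python
import torch
from torch_scatter import segment_csr, gather_csr

if torch.cuda.is_available():
    src = torch.randn(0, 0, 0, 16, device='cuda')
    indptr = torch.empty(0, dtype=torch.long, device='cuda')

    # Both calls return zero-sized tensors without the output-size lookup.
    print(segment_csr(src, indptr, reduce='max').shape)  # torch.Size([0, 0, 0, 16])
    print(gather_csr(src, indptr).shape)                 # torch.Size([0, 0, 0, 16])
```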

csrc/segment_csr.cpp

Lines changed: 10 additions & 8 deletions
@@ -82,14 +82,16 @@ class SegmentMeanCSR : public torch::autograd::Function<SegmentMeanCSR> {
     auto indptr = saved[0];
     auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
     auto grad_in = torch::empty(src_shape, grad_out.options());
-    gather_csr_fw(grad_out, indptr, grad_in);
-    auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
-    auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
-    auto count = (indptr2 - indptr1).to(grad_in.options());
-    count = gather_csr_fw(count, indptr, torch::nullopt);
-    for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
-      count = count.unsqueeze(-1);
-    grad_in.div_(count);
+    if (grad_in.numel() > 0) {
+      gather_csr_fw(grad_out, indptr, grad_in);
+      auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
+      auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
+      auto count = (indptr2 - indptr1).to(grad_in.options());
+      count = gather_csr_fw(count, indptr, torch::nullopt);
+      for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
+        count = count.unsqueeze(-1);
+      grad_in.div_(count);
+    }
     return {grad_in, Variable(), Variable()};
   }
 };
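
The guarded block divides each gathered gradient by its segment length; with a zero-element `grad_in` the whole computation is skipped. The equivalent arithmetic in plain PyTorch for a 1-D example (an illustration of the backward rule, not the autograd code itself):

```python
import torch

indptr = torch.tensor([0, 2, 5])             # two segments, lengths 2 and 3
grad_out = torch.tensor([10.0, 9.0])         # upstream gradient per segment

count = (indptr[1:] - indptr[:-1]).float()   # segment lengths: [2., 3.]
# d(mean)/d(x_i) = 1 / len(segment), so each element gets grad_out / count.
grad_in = (grad_out / count).repeat_interleave(count.long())
print(grad_in)  # tensor([5., 5., 3., 3., 3.])
```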

test/test_zero_tensors.py

Lines changed: 33 additions & 7 deletions
@@ -1,11 +1,37 @@
+from itertools import product
+
+import pytest
 import torch
-from torch_scatter import scatter
+from torch_scatter import scatter, segment_coo, gather_coo
+from torch_scatter import segment_csr, gather_csr
+
+from .utils import reductions, tensor, grad_dtypes, devices
+
+
+@pytest.mark.parametrize('reduce,dtype,device',
+                         product(reductions, grad_dtypes, devices))
+def test_zero_elements(reduce, dtype, device):
+    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
+                    requires_grad=True)
+    index = tensor([], torch.long, device)
+    indptr = tensor([], torch.long, device)
+
+    out = scatter(x, index, dim=0, dim_size=0, reduce=reduce)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
+
+    out = segment_coo(x, index, dim_size=0, reduce=reduce)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
 
+    out = gather_coo(x, index)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
 
-def test_zero_elements():
-    x = torch.randn(0, 16)
-    index = torch.tensor([]).view(0, 16)
-    print(x)
-    print(index)
+    out = segment_csr(x, indptr, reduce=reduce)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
 
-    scatter(x, index, dim=0, dim_size=0, reduce="add")
+    out = gather_csr(x, indptr)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)

torch_scatter/scatter.py

Lines changed: 22 additions & 12 deletions
@@ -12,20 +12,13 @@
 except OSError:
     warnings.warn('Failed to load `scatter` binaries.')
 
-    def scatter_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
-                            out: Optional[torch.Tensor],
-                            dim_size: Optional[int]) -> torch.Tensor:
-        raise ImportError
-        return src
-
     def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
                                      dim: int, out: Optional[torch.Tensor],
                                      dim_size: Optional[int]
                                      ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise ImportError
         return src, index
 
-    torch.ops.torch_scatter.scatter_mean = scatter_placeholder
     torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
     torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
 
@@ -37,11 +30,13 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     index = broadcast(index, src, dim)
     if out is None:
         size = src.size()
-        if dim_size is None:
-            size[dim] = int(index.max()) + 1
-        else:
+        if dim_size is not None:
             size[dim] = dim_size
-        out = src.new_zeros(size)
+        elif index.numel() == 0:
+            size[dim] = 0
+        else:
+            size[dim] = int(index.max()) + 1
+        out = torch.zeros(size, dtype=src.dtype, device=src.device)
         return out.scatter_add_(dim, index, src)
     else:
         return out.scatter_add_(dim, index, src)
@@ -58,7 +53,22 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
 def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                  out: Optional[torch.Tensor] = None,
                  dim_size: Optional[int] = None) -> torch.Tensor:
-    return torch.ops.torch_scatter.scatter_mean(src, index, dim, out, dim_size)
+
+    out = scatter_sum(src, index, dim, out, dim_size)
+    dim_size = out.size(dim)
+
+    index_dim = dim
+    if index_dim < 0:
+        index_dim = index_dim + src.dim()
+    if index.dim() <= dim:
+        index_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter_sum(ones, index, index_dim, None, dim_size)
+    count.clamp_(1)
+    count = broadcast(count, out, dim)
+    out.div_(count)
+    return out
 
 
 @torch.jit.script
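
`scatter_mean` is now composed from two `scatter_sum` calls (one for values, one for counts) instead of a dedicated C++ op, so the zero-element case falls out of the `scatter_sum` guards above. A small sketch of the resulting behaviour, assuming a build with this commit:

```python
import torch
from torch_scatter import scatter_mean

src = torch.tensor([[2.0, 4.0], [6.0, 8.0], [1.0, 1.0]])
index = torch.tensor([0, 0, 1])

# Segment 0 averages the first two rows, segment 1 keeps the last row.
print(scatter_mean(src, index, dim=0))
# tensor([[4., 6.],
#         [1., 1.]])

# Empty index and no dim_size: the output collapses to size 0 along dim.
empty = scatter_mean(torch.randn(0, 2), torch.empty(0, dtype=torch.long), dim=0)
print(empty.shape)  # torch.Size([0, 2])
```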
