
Commit bb87ec6

coo cpu implementation
1 parent 0c887ff commit bb87ec6

12 files changed: +585 -48 lines

cpu/segment_coo_impl.h

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
 
       if (e < E - 1) {
         next_idx = index_info.data[offset + (e + 1) * stride];
-        CHECK_INPUT(idx < E && idx <= next_idx);
+        CHECK_INPUT(idx <= next_idx);
 
         if (idx != next_idx) {
           idx = next_idx;
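For context: idx addresses rows of src while E counts index entries, so the dropped half of the check appears to have compared idx against the wrong extent. What remains is the one invariant the COO kernels actually rely on, an index that is non-decreasing along its last dimension. A standalone sketch of that invariant in plain C++ (the helper is hypothetical, not part of the commit):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: a COO index is valid for gather_coo/segment_coo
// only if it is non-decreasing along its innermost dimension.
bool is_nondecreasing(const std::vector<int64_t> &index) {
  for (std::size_t e = 0; e + 1 < index.size(); ++e)
    if (index[e] > index[e + 1])
      return false;
  return true;
}

int main() {
  assert(is_nondecreasing({0, 0, 1, 1, 3}));  // gaps are fine, order matters
  assert(!is_nondecreasing({0, 2, 1}));       // 2 > 1 would trip CHECK_INPUT
  return 0;
}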

csrc/cpu/segment_coo_cpu.cpp

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
+#include "segment_coo_cpu.h"
+
+#include "index_info.h"
+#include "reducer.h"
+#include "utils.h"
+
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo_cpu(torch::Tensor src, torch::Tensor index,
+                torch::optional<torch::Tensor> optional_out,
+                torch::optional<int64_t> dim_size, std::string reduce) {
+  CHECK_CPU(src);
+  CHECK_CPU(index);
+  if (optional_out.has_value())
+    CHECK_CPU(optional_out.value());
+
+  CHECK_INPUT(src.dim() >= index.dim());
+
+  auto sizes = index.sizes().vec();
+  for (int i = 0; i < index.dim(); i++)
+    sizes[i] = src.size(i);
+  index = index.expand(sizes);
+
+  auto dim = index.dim() - 1;
+
+  src = src.contiguous();
+
+  torch::Tensor out;
+  if (optional_out.has_value()) {
+    out = optional_out.value().contiguous();
+    for (int i = 0; i < out.dim(); i++)
+      if (i != dim)
+        CHECK_INPUT(src.size(i) == out.size(i));
+  } else {
+    sizes = src.sizes().vec();
+    if (dim_size.has_value())
+      sizes[dim] = dim_size.value();
+    else
+      sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
+    out = torch::empty(sizes, src.options());
+  }
+
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
+  int64_t *arg_out_data = nullptr;
+  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
+    arg_out = torch::full_like(out, src.size(dim), index.options());
+    arg_out_data = arg_out.value().data_ptr<int64_t>();
+  }
+
+  torch::optional<torch::Tensor> count = torch::nullopt;
+  if (reduce2REDUCE.at(reduce) == MEAN) {
+    auto sizes = index.sizes().vec();
+    sizes[dim] = out.size(dim);
+    count = torch::zeros(sizes, out.options());
+  }
+
+  auto B = index.numel() / src.size(dim);
+  auto E = src.size(dim);
+  auto K = src.numel() / index.numel();
+  auto N = out.size(dim);
+
+  auto index_info = getTensorInfo<int64_t>(index);
+  auto stride = index_info.strides[index_info.dims - 1];
+  std::vector<int64_t> args(K);
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
+    auto src_data = src.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+    scalar_t *count_data = nullptr;
+
+    std::vector<scalar_t> vals(K);
+    int64_t idx, next_idx, row_start;
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      if (!optional_out.has_value())
+        out.fill_(Reducer<scalar_t, REDUCE>::init());
+      if (REDUCE == MEAN)
+        count_data = count.value().data_ptr<scalar_t>();
+
+      for (auto b = 0; b < B; b++) {
+        auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
+        idx = index_info.data[offset];
+
+        for (auto k = 0; k < K; k++)
+          vals[k] = out_data[b * N * K + k];
+
+        row_start = 0;
+        for (auto e = 0; e < E; e++) {
+
+          for (auto k = 0; k < K; k++)
+            Reducer<scalar_t, REDUCE>::update(
+                &vals[k], src_data[b * E * K + e * K + k], &args[k], e);
+
+          if (e == E - 1) {
+            for (auto k = 0; k < K; k++)
+              Reducer<scalar_t, REDUCE>::write(
+                  out_data + b * N * K + idx * K + k, vals[k],
+                  arg_out_data + b * N * K + idx * K + k, args[k],
+                  e + 1 - row_start);
+            if (REDUCE == MEAN)
+              count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
+          } else {
+            next_idx = index_info.data[offset + (e + 1) * stride];
+            assert(idx <= next_idx);
+
+            if (idx != next_idx) {
+              for (auto k = 0; k < K; k++) {
+                Reducer<scalar_t, REDUCE>::write(
+                    out_data + b * N * K + idx * K + k, vals[k],
+                    arg_out_data + b * N * K + idx * K + k, args[k],
+                    e + 1 - row_start);
+
+                vals[k] = out_data[b * N * K + next_idx * K + k];
+              }
+              if (REDUCE == MEAN)
+                count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
+              row_start = e + 1;
+            }
+
+            idx = next_idx;
+          }
+        }
+      }
+      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
+        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
+
+      if (REDUCE == MEAN)
+        arg_out = count;
+    });
+  });
+
+  return std::make_tuple(out, arg_out);
+}
+
+torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
+                             torch::optional<torch::Tensor> optional_out) {
+  CHECK_CPU(src);
+  CHECK_CPU(index);
+  if (optional_out.has_value())
+    CHECK_CPU(optional_out.value());
+
+  CHECK_INPUT(src.dim() >= index.dim());
+  for (auto i = 0; i < index.dim() - 1; i++)
+    CHECK_INPUT(src.size(i) == index.size(i));
+
+  auto dim = index.dim() - 1;
+
+  src = src.contiguous();
+
+  torch::Tensor out;
+  if (optional_out.has_value()) {
+    out = optional_out.value().contiguous();
+    for (auto i = 0; i < src.dim(); i++)
+      if (i != dim)
+        CHECK_INPUT(src.size(i) == out.size(i));
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[dim] = index.size(dim);
+    out = torch::empty(sizes, src.options());
+  }
+
+  auto B = index.numel() / out.size(dim);
+  auto E = index.size(dim);
+  auto K = out.numel() / index.numel();
+  auto N = src.size(dim);
+
+  auto index_info = getTensorInfo<int64_t>(index);
+  auto stride = index_info.strides[index_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
+    auto src_data = src.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();
+
+    std::vector<scalar_t> vals(K);
+    int64_t idx, next_idx;
+    for (auto b = 0; b < B; b++) {
+      auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
+      idx = index_info.data[offset];
+
+      for (auto k = 0; k < K; k++)
+        vals[k] = src_data[b * N * K + idx * K + k];
+
+      for (auto e = 0; e < E; e++) {
+        for (auto k = 0; k < K; k++)
+          out_data[b * E * K + e * K + k] = vals[k];
+
+        if (e < E - 1) {
+          next_idx = index_info.data[offset + (e + 1) * stride];
+          CHECK_INPUT(idx <= next_idx);
+
+          if (idx != next_idx) {
+            idx = next_idx;
+            for (auto k = 0; k < K; k++)
+              vals[k] = src_data[b * N * K + idx * K + k];
+          }
+        }
+      }
+    }
+  });
+
+  return out;
+}
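The nested dispatch macros above obscure a simple pattern: one pass over the E entries per batch, accumulating K feature values and flushing them to out whenever the sorted index advances. A minimal plain-C++ sketch of that pattern for the sum case with B = 1 and K = 1 (names and the zero-initialized output are illustrative assumptions, not the commit's API):

#include <cstdint>
#include <cstdio>
#include <vector>

// Single pass over E source values with a sorted COO index: accumulate
// into `val` and flush to out[idx] whenever the index advances (or at
// the end of the row). Mirrors the kernel's row_start bookkeeping.
std::vector<double> segment_sum(const std::vector<double> &src,
                                const std::vector<int64_t> &index,
                                int64_t N) {
  std::vector<double> out(N, 0.0);
  if (src.empty())
    return out;

  int64_t idx = index[0];
  double val = 0.0;
  for (std::size_t e = 0; e < src.size(); ++e) {
    val += src[e];
    bool last = (e + 1 == src.size());
    if (last || index[e + 1] != idx) {
      out[idx] += val;  // flush the finished segment
      val = 0.0;
      if (!last)
        idx = index[e + 1];
    }
  }
  return out;
}

int main() {
  // index {0,0,1,2,2} groups src into segments {1+2, 3, 4+5}.
  auto out = segment_sum({1, 2, 3, 4, 5}, {0, 0, 1, 2, 2}, 3);
  for (double v : out)
    std::printf("%g ", v);  // prints: 3 3 9
  std::printf("\n");
}

Keeping the running reduction in a small local buffer and writing each segment exactly once is what makes the sorted COO layout attractive on CPU: the inner loop walks src sequentially, and out is touched once per segment rather than once per element.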

csrc/cpu/segment_coo_cpu.h

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/extension.h>
+
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo_cpu(torch::Tensor src, torch::Tensor index,
+                torch::optional<torch::Tensor> optional_out,
+                torch::optional<int64_t> dim_size, std::string reduce);
+
+torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
+                             torch::optional<torch::Tensor> optional_out);
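A hypothetical caller of these declarations, assuming the file is compiled as part of the extension (segment_coo_cpu.h pulls in torch/extension.h, which needs the Python headers on the include path); the shapes, the values, and the "sum" reduce string are illustrative assumptions, not taken from the commit:

#include <iostream>

#include "segment_coo_cpu.h"

int main() {
  // Five rows of two features, grouped by a sorted COO index into
  // three segments: rows {0, 1}, {2}, and {3, 4}.
  auto src = torch::tensor({{1.0, 10.0},
                            {2.0, 20.0},
                            {3.0, 30.0},
                            {4.0, 40.0},
                            {5.0, 50.0}});
  auto index = torch::tensor({0, 0, 1, 2, 2}, torch::kInt64);

  // Reduce rows that share an index value; arg_out stays nullopt for
  // "sum" (it carries argmin/argmax indices or, for "mean", counts).
  auto [out, arg_out] = segment_coo_cpu(src, index,
                                        /*optional_out=*/torch::nullopt,
                                        /*dim_size=*/3, "sum");
  std::cout << out << std::endl;  // [[3, 30], [3, 30], [9, 90]]

  // gather_coo is the converse: copy out[index[e]] back to row e.
  auto gathered = gather_coo_cpu(out, index, torch::nullopt);
  std::cout << gathered << std::endl;
}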

csrc/cpu/segment_csr_cpu.cpp

Lines changed: 1 addition & 3 deletions
@@ -67,12 +67,10 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
       for (auto k = 0; k < K; k++)
         vals[k] = Reducer<scalar_t, REDUCE>::init();
 
-      for (auto e = row_start; e < row_end; e++) {
-        CHECK_INPUT(e < E);
+      for (auto e = row_start; e < row_end; e++)
         for (auto k = 0; k < K; k++)
           Reducer<scalar_t, REDUCE>::update(
               &vals[k], src_data[offset + e * K + k], &args[k], e);
-      }
 
       for (auto k = 0; k < K; k++)
         Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
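For contrast with the COO kernels added in this commit: segment_csr receives segment boundaries directly via indptr, so the row_start/row_end bounds above can be trusted instead of re-checked per element. A sketch of how the two encodings relate, in plain C++ with an illustrative converter (not part of the commit):

#include <cassert>
#include <cstdint>
#include <vector>

// A sorted COO index and a CSR indptr encode the same segmentation:
// indptr[n]..indptr[n + 1] delimits segment n.
std::vector<int64_t> coo_to_csr(const std::vector<int64_t> &index, int64_t N) {
  std::vector<int64_t> indptr(N + 1, 0);
  for (auto i : index)
    indptr[i + 1]++;            // count elements per segment
  for (int64_t n = 0; n < N; ++n)
    indptr[n + 1] += indptr[n]; // prefix sum -> segment boundaries
  return indptr;
}

int main() {
  // Segments {0,0,1,2,2} correspond to boundaries {0,2,3,5}.
  assert(coo_to_csr({0, 0, 1, 2, 2}, 3) == std::vector<int64_t>({0, 2, 3, 5}));
  return 0;
}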

csrc/cuda/segment_coo_cuda.cu

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+#include "segment_coo_cuda.h"
+
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo_cuda(torch::Tensor src, torch::Tensor index,
+                 torch::optional<torch::Tensor> optional_out,
+                 torch::optional<int64_t> dim_size, std::string reduce) {
+  return std::make_tuple(src, optional_out);
+}
+
+torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
+                              torch::optional<torch::Tensor> optional_out) {
+  return src;
+}

csrc/cuda/segment_coo_cuda.h

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/extension.h>
+
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo_cuda(torch::Tensor src, torch::Tensor index,
+                 torch::optional<torch::Tensor> optional_out,
+                 torch::optional<int64_t> dim_size, std::string reduce);
+
+torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
+                              torch::optional<torch::Tensor> optional_out);
