
Commit 0c887ff

segment/gather csr done

1 parent 26a9e98 commit 0c887ff

File tree

17 files changed: +1327 -118 lines changed


LICENSE

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-Copyright (c) 2019 Matthias Fey <[email protected]>
+Copyright (c) 2020 Matthias Fey <[email protected]>
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

csrc/cpu/index_info.h

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
#pragma once

#include <torch/extension.h>

#define MAX_TENSORINFO_DIMS 25

// Lightweight CPU-side view of a tensor: raw data pointer plus sizes/strides.
template <typename scalar_t> struct TensorInfo {
  TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
             int st[MAX_TENSORINFO_DIMS]) {
    data = p;
    dims = dim;
    AT_ASSERT(dims < MAX_TENSORINFO_DIMS);

    for (int i = 0; i < dim; ++i) {
      sizes[i] = sz[i];
      strides[i] = st[i];
    }
  }

  scalar_t *data;
  int dims;
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];
};

template <typename scalar_t>
TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];

  int dims = tensor.dim();
  for (int i = 0; i < dims; ++i) {
    sizes[i] = tensor.size(i);
    strides[i] = tensor.stride(i);
  }

  return TensorInfo<scalar_t>(tensor.data_ptr<scalar_t>(), dims, sizes,
                              strides);
}

// Maps a flat element index to a memory offset, honoring arbitrary strides.
template <typename scalar_t> struct IndexToOffset {
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    int offset = 0;
    for (int i = info.dims - 1; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};

// Same idea, but the last dimension holds CSR index pointers: a slice of
// length L describes L - 1 segments, hence the `- 1` terms below.
template <typename scalar_t> struct IndexPtrToOffset {
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    int offset = idx % (info.sizes[info.dims - 1] - 1);
    offset *= info.strides[info.dims - 1];
    idx /= info.sizes[info.dims - 1] - 1;
    for (int i = info.dims - 2; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};
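
IndexPtrToOffset is what later lets segment_csr_cpu and gather_csr_cpu walk an expanded, possibly non-contiguous indptr tensor with a single flat loop counter n. A minimal standalone sketch of the same arithmetic (illustration only, not part of the commit; sizes and strides are hard-coded and there is no torch dependency):

// Mirrors IndexPtrToOffset for an indptr of shape (2, 4): two batches,
// each with 4 - 1 = 3 segments, laid out contiguously (strides = {4, 1}).
#include <cstdio>

int main() {
  int sizes[2] = {2, 4};
  int strides[2] = {4, 1};
  int dims = 2;

  for (int n = 0; n < 2 * 3; n++) {
    int idx = n;
    // Innermost dimension indexes segments, of which there are sizes[1] - 1.
    int offset = idx % (sizes[dims - 1] - 1);
    offset *= strides[dims - 1];
    idx /= sizes[dims - 1] - 1;
    for (int i = dims - 2; i >= 0; --i) {
      offset += (idx % sizes[i]) * strides[i];
      idx /= sizes[i];
    }
    // Prints offsets 0, 1, 2 for batch 0 and 4, 5, 6 for batch 1; segment n
    // then spans indptr[offset] .. indptr[offset + stride].
    std::printf("segment %d -> indptr offset %d\n", n, offset);
  }
  return 0;
}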

csrc/cpu/reducer.h

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
#pragma once

#include <limits>
#include <map>

enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
    {"div", DIV}, {"min", MIN}, {"max", MAX},
};

// Dispatches on the runtime `reduce` string and exposes the matching
// compile-time constant `REDUCE` inside the passed lambda.
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                           \
  [&] {                                                                    \
    switch (reduce2REDUCE.at(reduce)) {                                    \
    case SUM: {                                                            \
      const ReductionType REDUCE = SUM;                                    \
      return __VA_ARGS__();                                                \
    }                                                                      \
    case MEAN: {                                                           \
      const ReductionType REDUCE = MEAN;                                   \
      return __VA_ARGS__();                                                \
    }                                                                      \
    case MUL: {                                                            \
      const ReductionType REDUCE = MUL;                                    \
      return __VA_ARGS__();                                                \
    }                                                                      \
    case DIV: {                                                            \
      const ReductionType REDUCE = DIV;                                    \
      return __VA_ARGS__();                                                \
    }                                                                      \
    case MIN: {                                                            \
      const ReductionType REDUCE = MIN;                                    \
      return __VA_ARGS__();                                                \
    }                                                                      \
    case MAX: {                                                            \
      const ReductionType REDUCE = MAX;                                    \
      return __VA_ARGS__();                                                \
    }                                                                      \
    }                                                                      \
  }()

template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  // Identity element of the reduction (0 for sum/mean, 1 for mul/div,
  // +inf/-inf-like sentinels for min/max).
  static inline scalar_t init() {
    if (REDUCE == MUL || REDUCE == DIV)
      return (scalar_t)1;
    else if (REDUCE == MIN)
      return std::numeric_limits<scalar_t>::max();
    else if (REDUCE == MAX)
      return std::numeric_limits<scalar_t>::lowest();
    else
      return (scalar_t)0;
  }

  // Folds `new_val` into the running value; min/max also track the
  // argmin/argmax position in `arg`.
  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
                            int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN)
      *val = *val + new_val;
    else if (REDUCE == MUL)
      *val = *val * new_val;
    else if (REDUCE == DIV)
      *val = *val / new_val;
    else if ((REDUCE == MIN && new_val < *val) ||
             (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  // Writes the final value (and arg index) back, normalizing for mean and
  // writing 0 for empty min/max segments.
  static inline void write(scalar_t *address, scalar_t val,
                           int64_t *arg_address, int64_t arg, int count) {
    if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
      *address = val;
    else if (REDUCE == MEAN)
      *address = val / (count > 0 ? count : (scalar_t)1);
    else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else
        *address = (scalar_t)0;
    }
  }
};
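
Together, AT_DISPATCH_REDUCTION_TYPES and Reducer turn the runtime reduce string into a compile-time constant, so each reduction's init/update/write logic is resolved statically inside the dispatched lambda. An illustrative sketch, not part of the commit, assuming csrc/cpu is on the include path:

// Reduce a small array with the Reducer/AT_DISPATCH_REDUCTION_TYPES pair.
#include <cstdint>
#include <cstdio>
#include <string>

#include "reducer.h"

int main() {
  float data[5] = {3.f, -1.f, 7.f, 2.f, 7.f};
  std::string reduce = "max";

  AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
    float val = Reducer<float, REDUCE>::init();  // identity element
    int64_t arg = 5;                             // "empty" sentinel
    for (int64_t e = 0; e < 5; e++)
      Reducer<float, REDUCE>::update(&val, data[e], &arg, e);

    float out;
    int64_t arg_out;
    Reducer<float, REDUCE>::write(&out, val, &arg_out, arg, /*count=*/5);
    // Prints: max = 7.000000 at index 2
    std::printf("%s = %f at index %lld\n", reduce.c_str(), out,
                (long long)arg_out);
  });
  return 0;
}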

csrc/cpu/segment_csr_cpu.cpp

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
#include "segment_csr_cpu.h"

#include "index_info.h"
#include "reducer.h"
#include "utils.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out,
                std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= indptr.dim());

  // Broadcast the leading dimensions of `indptr` to match `src`.
  auto sizes = indptr.sizes().vec();
  for (auto i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
    CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
  } else {
    sizes = src.sizes().vec();
    sizes[dim] = indptr.size(dim) - 1;
    out = torch::empty(sizes, src.options());
  }

  // Min/max additionally return argmin/argmax, initialized to src.size(dim)
  // so empty segments are recognizable.
  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  // N: total number of output segments, K: inner elements per entry (dims
  // after `dim`), E: length of `src` along the segmented dimension.
  auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(dim);

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  std::vector<int64_t> args(K);
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    std::vector<scalar_t> vals(K);
    int64_t row_start, row_end;
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      for (auto n = 0; n < N; n++) {
        auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
        row_start = indptr_info.data[offset];
        row_end = indptr_info.data[offset + stride];

        offset = (n / (indptr.size(-1) - 1)) * E * K;
        for (auto k = 0; k < K; k++)
          vals[k] = Reducer<scalar_t, REDUCE>::init();

        for (auto e = row_start; e < row_end; e++) {
          CHECK_INPUT(e < E);
          for (auto k = 0; k < K; k++)
            Reducer<scalar_t, REDUCE>::update(
                &vals[k], src_data[offset + e * K + k], &args[k], e);
        }

        for (auto k = 0; k < K; k++)
          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
                                           arg_out_data + n * K + k, args[k],
                                           row_end - row_start);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
                             torch::optional<torch::Tensor> optional_out) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= indptr.dim());

  auto sizes = indptr.sizes().vec();
  for (auto i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;
  CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    // Infer the output length from the last index pointer entry.
    auto sizes = src.sizes().vec();
    sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
    out = torch::empty(sizes, src.options());
  }

  auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
  auto K = src.numel() / N;
  auto E = out.size(dim);

  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    std::vector<scalar_t> vals(K);
    int64_t row_start, row_end;
    for (int n = 0; n < N; n++) {
      auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
      row_start = indptr_info.data[offset];
      row_end = indptr_info.data[offset + stride];

      for (auto k = 0; k < K; k++)
        vals[k] = src_data[n * K + k];

      // Broadcast each input row to all positions of its segment.
      offset = (n / (indptr.size(-1) - 1)) * E * K;
      for (auto e = row_start; e < row_end; e++)
        for (auto k = 0; k < K; k++)
          out_data[offset + e * K + k] = vals[k];
    }
  });

  return out;
}
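
A hypothetical call into the new kernel, sketched against the libtorch API (not part of this commit, and it assumes the extension's build environment since segment_csr_cpu.h pulls in <torch/extension.h>): six input rows are split into the three segments [0, 2), [2, 2) and [2, 6), and the "max" reduction also returns argmax indices, with empty segments written as 0 and their arg set to src.size(0).

#include <iostream>
#include <tuple>

#include <torch/torch.h>

#include "segment_csr_cpu.h"

int main() {
  auto src = torch::arange(12, torch::kFloat).view({6, 2});
  auto indptr = torch::tensor({0, 2, 2, 6}, torch::kLong);

  torch::Tensor out;
  torch::optional<torch::Tensor> arg_out;
  std::tie(out, arg_out) =
      segment_csr_cpu(src, indptr, torch::nullopt, "max");

  std::cout << out << std::endl;       // per-segment max, shape (3, 2)
  std::cout << *arg_out << std::endl;  // argmax row indices per segment
  return 0;
}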

csrc/cpu/segment_csr_cpu.h

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out,
                std::string reduce);

torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
                             torch::optional<torch::Tensor> optional_out);
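
The commit's binding code lives in files not shown in this excerpt; purely as an illustration, these declarations could be exposed to Python with a pybind11 module like the following (module layout and docstrings are hypothetical, not taken from the commit):

// Hypothetical binding sketch; torch/extension.h provides the pybind11
// type casters for torch::Tensor and torch::optional.
#include <torch/extension.h>

#include "segment_csr_cpu.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("segment_csr_cpu", &segment_csr_cpu, "Segment-CSR reduction (CPU)");
  m.def("gather_csr_cpu", &gather_csr_cpu, "Gather-CSR (CPU)");
}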

csrc/cpu/utils.h

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
#pragma once

#include <torch/extension.h>

#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
