Skip to content

Commit 5db0086

Browse files
committed
faster segment csr cpu implementation
1 parent 3994f3a commit 5db0086

File tree

1 file changed

+11 -7 lines changed

1 file changed

+11 -7 lines changed

cpu/segment.cpp

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -123,8 +123,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
123123
auto src_data = src.DATA_PTR<scalar_t>();
124124
auto out_data = out.DATA_PTR<scalar_t>();
125125

126-
scalar_t val;
127-
int64_t row_start, row_end, arg;
126+
scalar_t vals[K];
127+
int64_t row_start, row_end, args[K];
128128
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
129129
for (int n = 0; n < N; n++) {
130130
int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
@@ -133,13 +133,17 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
133133

134134
offset = (n / (indptr.size(-1) - 1)) * E * K;
135135
for (int k = 0; k < K; k++) {
136-
val = Reducer<scalar_t, REDUCE>::init();
137-
for (int64_t e = row_start; e < row_end; e++) {
136+
vals[k] = Reducer<scalar_t, REDUCE>::init();
137+
}
138+
for (int64_t e = row_start; e < row_end; e++) {
139+
for (int k = 0; k < K; k++) {
138140
Reducer<scalar_t, REDUCE>::update(
139-
&val, src_data[offset + e * K + k], &arg, e);
141+
&vals[k], src_data[offset + e * K + k], &args[k], e);
140142
}
141-
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, val,
142-
arg_out_data + n * K + k, arg,
143+
}
144+
for (int k = 0; k < K; k++) {
145+
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
146+
arg_out_data + n * K + k, args[k],
143147
row_end - row_start);
144148
}
145149
}

0 commit comments

Comments
 (0)