Commit 2eba313

fixed spspmm for cpu

1 parent 57852a6 commit 2eba313

File tree: 1 file changed (+24, -48 lines)

csrc/cpu/spspmm_cpu.cpp

Lines changed: 24 additions & 48 deletions
@@ -48,80 +48,56 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
   auto rowptrB_data = rowptrB.data_ptr<int64_t>();
   auto colB_data = colB.data_ptr<int64_t>();
 
-  // Pass 1: Compute CSR row pointer.
   auto rowptrC = torch::empty_like(rowptrA);
   auto rowptrC_data = rowptrC.data_ptr<int64_t>();
   rowptrC_data[0] = 0;
 
-  std::vector<int64_t> mask(K, -1);
-  int64_t nnz = 0, row_nnz, rowA_start, rowA_end, rowB_start, rowB_end, cA, cB;
-  for (auto n = 0; n < rowptrA.numel() - 1; n++) {
-    row_nnz = 0;
-
-    for (auto eA = rowptrA_data[n]; eA < rowptrA_data[n + 1]; eA++) {
-      cA = colA_data[eA];
-      for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
-        cB = colB_data[eB];
-        if (mask[cB] != n) {
-          mask[cB] = n;
-          row_nnz++;
-        }
-      }
-    }
-
-    nnz += row_nnz;
-    rowptrC_data[n + 1] = nnz;
-  }
-
-  // Pass 2: Compute CSR entries.
-  auto colC = torch::empty(nnz, rowptrC.options());
-  auto colC_data = colC.data_ptr<int64_t>();
-
+  torch::Tensor colC;
   torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
-  if (optional_valueA.has_value())
-    optional_valueC = torch::empty(nnz, optional_valueA.value().options());
 
   AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] {
-    AT_DISPATCH_HAS_VALUE(optional_valueC, [&] {
-      scalar_t *valA_data = nullptr, *valB_data = nullptr, *valC_data = nullptr;
+    AT_DISPATCH_HAS_VALUE(optional_valueA, [&] {
+      scalar_t *valA_data = nullptr, *valB_data = nullptr;
       if (HAS_VALUE) {
         valA_data = optional_valueA.value().data_ptr<scalar_t>();
         valB_data = optional_valueB.value().data_ptr<scalar_t>();
-        valC_data = optional_valueC.value().data_ptr<scalar_t>();
       }
-      scalar_t valA;
 
-      rowA_start = 0, nnz = 0;
-      std::vector<scalar_t> vals(K, 0);
-      for (auto n = 1; n < rowptrA.numel(); n++) {
-        rowA_end = rowptrA_data[n];
+      int64_t nnz = 0, cA, cB;
+      std::vector<scalar_t> tmp_vals(K, 0);
+      std::vector<int64_t> cols;
+      std::vector<scalar_t> vals;
 
-        for (auto eA = rowA_start; eA < rowA_end; eA++) {
+      for (auto rA = 0; rA < rowptrA.numel() - 1; rA++) {
+        for (auto eA = rowptrA_data[rA]; eA < rowptrA_data[rA + 1]; eA++) {
           cA = colA_data[eA];
-          if (HAS_VALUE)
-            valA = valA_data[eA];
-
-          rowB_start = rowptrB_data[cA], rowB_end = rowptrB_data[cA + 1];
-          for (auto eB = rowB_start; eB < rowB_end; eB++) {
+          for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
             cB = colB_data[eB];
+
             if (HAS_VALUE)
-              vals[cB] += valA * valB_data[eB];
+              tmp_vals[cB] += valA_data[eA] * valB_data[eB];
             else
-              vals[cB] += 1;
+              tmp_vals[cB]++;
           }
         }
 
         for (auto k = 0; k < K; k++) {
-          if (vals[k] != 0) {
-            colC_data[nnz] = k;
+          if (tmp_vals[k] != 0) {
+            cols.push_back(k);
             if (HAS_VALUE)
-              valC_data[nnz] = vals[k];
+              vals.push_back(tmp_vals[k]);
             nnz++;
           }
-          vals[k] = (scalar_t)0;
+          tmp_vals[k] = (scalar_t)0;
         }
+        rowptrC_data[rA + 1] = nnz;
+      }
 
-        rowA_start = rowA_end;
+      colC = torch::from_blob(cols.data(), {nnz}, colA.options()).clone();
+      if (HAS_VALUE) {
+        optional_valueC = torch::from_blob(vals.data(), {nnz},
+                                           optional_valueA.value().options());
+        optional_valueC = optional_valueC.value().clone();
       }
     });
   });
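
For reference, a minimal standalone sketch of the single-pass pattern the new code follows: each row of A is expanded into a dense accumulator of length K (the number of columns of B), the non-zeros are then compacted into growing column/value buffers, and the row pointer is updated as those buffers grow. The CSR struct and spgemm_csr function below are illustrative assumptions only (plain C++, no torch), not part of the library.

#include <cstdint>
#include <vector>

// Illustrative CSR container (assumption; not the library's types).
struct CSR {
  std::vector<int64_t> rowptr, col;
  std::vector<double> val;
};

// Row-wise CSR x CSR product with a dense accumulator, mirroring the
// approach in the patched kernel. K is the number of columns of B.
CSR spgemm_csr(const CSR &A, const CSR &B, int64_t K) {
  CSR C;
  C.rowptr.push_back(0);
  std::vector<double> tmp_vals(K, 0.0); // one slot per output column

  for (size_t rA = 0; rA + 1 < A.rowptr.size(); rA++) {
    // Scatter the partial products of row rA into the dense buffer.
    for (int64_t eA = A.rowptr[rA]; eA < A.rowptr[rA + 1]; eA++) {
      int64_t cA = A.col[eA];
      for (int64_t eB = B.rowptr[cA]; eB < B.rowptr[cA + 1]; eB++)
        tmp_vals[B.col[eB]] += A.val[eA] * B.val[eB];
    }
    // Compact the non-zeros into the output row and reset the buffer.
    for (int64_t k = 0; k < K; k++) {
      if (tmp_vals[k] != 0) {
        C.col.push_back(k);
        C.val.push_back(tmp_vals[k]);
      }
      tmp_vals[k] = 0;
    }
    C.rowptr.push_back(static_cast<int64_t>(C.col.size()));
  }
  return C;
}

Like the tmp_vals[k] != 0 check in the diff, this sketch drops entries whose products cancel to exactly zero.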
