#include "saint_cpu.h"

#include "utils.h"

// Extracts the subgraph induced by the nodes in `idx` from a graph stored in
// CSR format (`rowptr`, `col`). Returns the relabeled edge list of the
// subgraph together with the positions of the sampled edges in the original
// `col` tensor.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
             torch::Tensor col) {
  CHECK_CPU(idx);
  CHECK_CPU(rowptr);
  CHECK_CPU(col);

  CHECK_INPUT(idx.dim() == 1);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);

  // Map each sampled node to its new index; -1 marks nodes outside `idx`.
  auto assoc = torch::full({rowptr.size(0) - 1}, -1, idx.options());
  assoc.index_copy_(0, idx, torch::arange(idx.size(0), idx.options()));

  auto idx_data = idx.data_ptr<int64_t>();
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto assoc_data = assoc.data_ptr<int64_t>();

  std::vector<int64_t> rows, cols, indices;

  int64_t v, w, w_new, row_start, row_end;
  for (int64_t v_new = 0; v_new < idx.size(0); v_new++) {
    v = idx_data[v_new];
    row_start = rowptr_data[v];
    row_end = rowptr_data[v + 1];

    // Keep only edges whose target is also part of the sampled node set.
    for (int64_t j = row_start; j < row_end; j++) {
      w = col_data[j];
      w_new = assoc_data[w];
      if (w_new > -1) {
        rows.push_back(v_new);
        cols.push_back(w_new);
        indices.push_back(j);
      }
    }
  }

  // Copy the collected edges into freshly allocated output tensors.
  int64_t length = rows.size();
  row = torch::from_blob(rows.data(), {length}, row.options()).clone();
  col = torch::from_blob(cols.data(), {length}, row.options()).clone();
  idx = torch::from_blob(indices.data(), {length}, row.options()).clone();

  return std::make_tuple(row, col, idx);
}
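For reference, a minimal sketch of how this kernel could be exercised directly from C++, assuming libtorch is available and that `saint_cpu.h` declares `subgraph_cpu` as above; the 4-node toy graph, the expected output, and the `main()` harness are illustrative assumptions, not part of this commit.

#include <torch/torch.h>

#include <iostream>
#include <tuple>

#include "saint_cpu.h"

int main() {
  // CSR layout of a 4-node graph with edges
  // 0->1, 0->2, 1->3, 2->0, 2->3, 3->2 (stored in this order).
  auto rowptr = torch::tensor({0, 2, 3, 5, 6}, torch::kLong);
  auto row = torch::tensor({0, 0, 1, 2, 2, 3}, torch::kLong);
  auto col = torch::tensor({1, 2, 3, 0, 3, 2}, torch::kLong);
  auto idx = torch::tensor({0, 2}, torch::kLong); // nodes kept in the subgraph

  torch::Tensor sub_row, sub_col, edge_idx;
  std::tie(sub_row, sub_col, edge_idx) = subgraph_cpu(idx, rowptr, row, col);

  // Only 0->2 and 2->0 survive; relabeled they become 0->1 and 1->0, and
  // their positions in the original `col` tensor are 1 and 3.
  std::cout << sub_row << std::endl << sub_col << std::endl << edge_idx << std::endl;
  return 0;
}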