Skip to content

Commit fff381c

Browse files
committed
added saint extract_adj method
1 parent 92b1e63 commit fff381c

File tree

6 files changed

+115
-1
lines changed

6 files changed

+115
-1
lines changed

csrc/cpu/saint_cpu.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "saint_cpu.h"
2+
3+
#include "utils.h"
4+
5+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
6+
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
7+
torch::Tensor col) {
8+
CHECK_CPU(idx);
9+
CHECK_CPU(rowptr);
10+
CHECK_CPU(col);
11+
12+
CHECK_INPUT(idx.dim() == 1);
13+
CHECK_INPUT(rowptr.dim() == 1);
14+
CHECK_INPUT(col.dim() == 1);
15+
16+
auto assoc = torch::full({rowptr.size(0) - 1}, -1, idx.options());
17+
assoc.index_copy_(0, idx, torch::arange(idx.size(0), idx.options()));
18+
19+
auto idx_data = idx.data_ptr<int64_t>();
20+
auto rowptr_data = rowptr.data_ptr<int64_t>();
21+
auto col_data = col.data_ptr<int64_t>();
22+
auto assoc_data = assoc.data_ptr<int64_t>();
23+
24+
std::vector<int64_t> rows, cols, indices;
25+
26+
int64_t v, w, w_new, row_start, row_end;
27+
for (int64_t v_new = 0; v_new < idx.size(0); v_new++) {
28+
v = idx_data[v_new];
29+
row_start = rowptr_data[v];
30+
row_end = rowptr_data[v + 1];
31+
32+
for (int64_t j = row_start; j < row_end; j++) {
33+
w = col_data[j];
34+
w_new = assoc_data[w];
35+
if (w_new > -1) {
36+
rows.push_back(v_new);
37+
cols.push_back(w_new);
38+
indices.push_back(j);
39+
}
40+
}
41+
}
42+
43+
int64_t length = rows.size();
44+
row = torch::from_blob(rows.data(), {length}, row.options()).clone();
45+
col = torch::from_blob(cols.data(), {length}, row.options()).clone();
46+
idx = torch::from_blob(indices.data(), {length}, row.options()).clone();
47+
48+
return std::make_tuple(row, col, idx);
49+
}

csrc/cpu/saint_cpu.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
5+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
6+
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
7+
torch::Tensor col);

csrc/saint.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include <Python.h>
2+
#include <torch/script.h>
3+
4+
#include "cpu/saint_cpu.h"
5+
6+
#ifdef _WIN32
7+
PyMODINIT_FUNC PyInit__saint(void) { return NULL; }
8+
#endif
9+
10+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
11+
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
12+
torch::Tensor col) {
13+
if (idx.device().is_cuda()) {
14+
#ifdef WITH_CUDA
15+
AT_ERROR("No CUDA version supported");
16+
#else
17+
AT_ERROR("Not compiled with CUDA support");
18+
#endif
19+
} else {
20+
return subgraph_cpu(idx, rowptr, row, col);
21+
}
22+
}
23+
24+
static auto registry =
25+
torch::RegisterOperators().op("torch_sparse::saint_subgraph", &subgraph);

test/test_saint.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
import pytest
22
import torch
33
from torch_sparse.tensor import SparseTensor
4+
from torch_sparse.saint import subgraph
45

56
from .utils import devices
67

78

9+
@pytest.mark.parametrize('device', devices)
10+
def test_subgraph(device):
11+
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
12+
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
13+
adj = SparseTensor(row=row, col=col).to(device)
14+
node_idx = torch.tensor([0, 1, 2])
15+
16+
adj, edge_index = subgraph(adj, node_idx)
17+
18+
819
@pytest.mark.parametrize('device', devices)
920
def test_sample_node(device):
1021
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])

torch_sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
try:
1010
for library in [
1111
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
12-
'_rw'
12+
'_rw', '_saint'
1313
]:
1414
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
1515
library, [osp.dirname(__file__)]).origin)

torch_sparse/saint.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,28 @@
66
from torch_sparse.tensor import SparseTensor
77

88

9+
def subgraph(src: SparseTensor,
10+
node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]:
11+
row, col, value = src.coo()
12+
rowptr = src.storage.rowptr()
13+
14+
data = torch.ops.torch_sparse.saint_subgraph(node_idx, rowptr, row, col)
15+
row, col, edge_index = data
16+
17+
if value is not None:
18+
value = value[edge_index]
19+
20+
out = SparseTensor(
21+
row=row,
22+
rowptr=None,
23+
col=col,
24+
value=value,
25+
sparse_sizes=(node_idx.size(0), node_idx.size(0)),
26+
is_sorted=True)
27+
28+
return out, edge_index
29+
30+
931
def sample_node(src: SparseTensor,
1032
num_nodes: int) -> Tuple[SparseTensor, torch.Tensor]:
1133
row, col, _ = src.coo()

0 commit comments

Comments
 (0)