Skip to content

Commit eda4b3d

Browse files
committed
random walk
1 parent fff381c commit eda4b3d

File tree

4 files changed

+20
-93
lines changed

4 files changed

+20
-93
lines changed

test/test_saint.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,15 @@
11
import pytest
22
import torch
33
from torch_sparse.tensor import SparseTensor
4-
from torch_sparse.saint import subgraph
54

65
from .utils import devices
76

87

98
@pytest.mark.parametrize('device', devices)
10-
def test_subgraph(device):
9+
def test_saint_subgraph(device):
1110
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
1211
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
1312
adj = SparseTensor(row=row, col=col).to(device)
1413
node_idx = torch.tensor([0, 1, 2])
1514

16-
adj, edge_index = subgraph(adj, node_idx)
17-
18-
19-
@pytest.mark.parametrize('device', devices)
20-
def test_sample_node(device):
21-
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
22-
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
23-
adj = SparseTensor(row=row, col=col).to(device)
24-
25-
adj, perm = adj.sample_node(num_nodes=3)
26-
27-
28-
@pytest.mark.parametrize('device', devices)
29-
def test_sample_edge(device):
30-
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
31-
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
32-
adj = SparseTensor(row=row, col=col).to(device)
33-
34-
adj, perm = adj.sample_edge(num_edges=3)
35-
36-
37-
@pytest.mark.parametrize('device', devices)
38-
def test_sample_rw(device):
39-
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
40-
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
41-
adj = SparseTensor(row=row, col=col).to(device)
42-
43-
adj, perm = adj.sample_rw(num_root_nodes=3, walk_length=2)
15+
adj, edge_index = adj.saint_subgraph(node_idx)

torch_sparse/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@
5555
from .reduce import sum, mean, min, max # noqa
5656
from .matmul import matmul # noqa
5757
from .cat import cat, cat_diag # noqa
58+
from .rw import random_walk # noqa
5859
from .metis import partition # noqa
59-
from .saint import sample_node, sample_edge, sample_rw # noqa
60+
from .saint import saint_subgraph # noqa
6061

6162
from .convert import to_torch_sparse, from_torch_sparse # noqa
6263
from .convert import to_scipy, from_scipy # noqa
@@ -96,10 +97,9 @@
9697
'matmul',
9798
'cat',
9899
'cat_diag',
100+
'random_walk',
99101
'partition',
100-
'sample_node',
101-
'sample_edge',
102-
'sample_rw',
102+
'saint_subgraph',
103103
'to_torch_sparse',
104104
'from_torch_sparse',
105105
'to_scipy',

torch_sparse/rw.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
from torch_sparse.tensor import SparseTensor
3+
4+
5+
def random_walk(src: SparseTensor, start: torch.Tensor,
6+
walk_length: int) -> torch.Tensor:
7+
rowptr, col, _ = src.csr()
8+
return torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
9+
10+
11+
SparseTensor.random_walk = random_walk

torch_sparse/saint.py

Lines changed: 3 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from typing import Tuple
22

33
import torch
4-
import numpy as np
5-
from torch_scatter import scatter_add
64
from torch_sparse.tensor import SparseTensor
75

86

9-
def subgraph(src: SparseTensor,
10-
node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]:
7+
def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
8+
) -> Tuple[SparseTensor, torch.Tensor]:
119
row, col, value = src.coo()
1210
rowptr = src.storage.rowptr()
1311

@@ -28,58 +26,4 @@ def subgraph(src: SparseTensor,
2826
return out, edge_index
2927

3028

31-
def sample_node(src: SparseTensor,
32-
num_nodes: int) -> Tuple[SparseTensor, torch.Tensor]:
33-
row, col, _ = src.coo()
34-
35-
inv_in_deg = src.storage.colcount().to(torch.float).pow_(-1)
36-
inv_in_deg[inv_in_deg == float('inf')] = 0
37-
38-
prob = inv_in_deg[col]
39-
prob.mul_(prob)
40-
41-
prob = scatter_add(prob, row, dim=0, dim_size=src.size(0))
42-
prob.div_(prob.sum())
43-
44-
node_idx = prob.multinomial(num_nodes, replacement=True).unique()
45-
46-
return src.permute(node_idx), node_idx
47-
48-
49-
def sample_edge(src: SparseTensor,
50-
num_edges: int) -> Tuple[SparseTensor, torch.Tensor]:
51-
52-
row, col, _ = src.coo()
53-
54-
inv_out_deg = src.storage.rowcount().to(torch.float).pow_(-1)
55-
inv_out_deg[inv_out_deg == float('inf')] = 0
56-
inv_in_deg = src.storage.colcount().to(torch.float).pow_(-1)
57-
inv_in_deg[inv_in_deg == float('inf')] = 0
58-
59-
prob = inv_out_deg[row] + inv_in_deg[col]
60-
prob.div_(prob.sum())
61-
62-
edge_idx = prob.multinomial(num_edges, replacement=True)
63-
node_idx = col[edge_idx].unique()
64-
65-
return src.permute(node_idx), node_idx
66-
67-
68-
def sample_rw(src: SparseTensor, num_root_nodes: int,
69-
walk_length: int) -> Tuple[SparseTensor, torch.Tensor]:
70-
71-
rowptr, col, _ = src.csr()
72-
73-
start = np.random.choice(src.size(0), size=num_root_nodes, replace=False)
74-
start = torch.from_numpy(start).to(src.device(), torch.long)
75-
76-
out = torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
77-
78-
node_idx = out.flatten().unique()
79-
80-
return src.permute(node_idx), node_idx
81-
82-
83-
SparseTensor.sample_node = sample_node
84-
SparseTensor.sample_edge = sample_edge
85-
SparseTensor.sample_rw = sample_rw
29+
SparseTensor.saint_subgraph = saint_subgraph

0 commit comments

Comments
 (0)