|
1 | 1 | import pytest |
2 | 2 | import torch |
3 | 3 | from torch_sparse.tensor import SparseTensor |
4 | | -from torch_sparse.saint import subgraph |
5 | 4 |
|
6 | 5 | from .utils import devices |
7 | 6 |
|
8 | 7 |
|
9 | 8 | @pytest.mark.parametrize('device', devices) |
10 | | -def test_subgraph(device): |
| 9 | +def test_saint_subgraph(device): |
11 | 10 | row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4]) |
12 | 11 | col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3]) |
13 | 12 | adj = SparseTensor(row=row, col=col).to(device) |
14 | 13 | node_idx = torch.tensor([0, 1, 2]) |
15 | 14 |
|
16 | | - adj, edge_index = subgraph(adj, node_idx) |
17 | | - |
18 | | - |
19 | | -@pytest.mark.parametrize('device', devices) |
20 | | -def test_sample_node(device): |
21 | | - row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4]) |
22 | | - col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3]) |
23 | | - adj = SparseTensor(row=row, col=col).to(device) |
24 | | - |
25 | | - adj, perm = adj.sample_node(num_nodes=3) |
26 | | - |
27 | | - |
28 | | -@pytest.mark.parametrize('device', devices) |
29 | | -def test_sample_edge(device): |
30 | | - row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4]) |
31 | | - col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3]) |
32 | | - adj = SparseTensor(row=row, col=col).to(device) |
33 | | - |
34 | | - adj, perm = adj.sample_edge(num_edges=3) |
35 | | - |
36 | | - |
37 | | -@pytest.mark.parametrize('device', devices) |
38 | | -def test_sample_rw(device): |
39 | | - row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4]) |
40 | | - col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3]) |
41 | | - adj = SparseTensor(row=row, col=col).to(device) |
42 | | - |
43 | | - adj, perm = adj.sample_rw(num_root_nodes=3, walk_length=2) |
| 15 | + adj, edge_index = adj.saint_subgraph(node_idx) |
0 commit comments