Skip to content

Commit 78d9af4

Browse files
committed
sample adj
1 parent d3ae9f1 commit 78d9af4

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

torch_sparse/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
for library in [
99
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
10-
'_saint', '_padding'
10+
'_saint', '_padding', '_sample'
1111
]:
1212
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
1313
library, [osp.dirname(__file__)]).origin)
@@ -50,7 +50,7 @@
5050
from .bandwidth import reverse_cuthill_mckee # noqa
5151
from .saint import saint_subgraph # noqa
5252
from .padding import padded_index, padded_index_select # noqa
53-
from .sample import sample # noqa
53+
from .sample import sample, sample_adj # noqa
5454

5555
from .convert import to_torch_sparse, from_torch_sparse # noqa
5656
from .convert import to_scipy, from_scipy # noqa

torch_sparse/sample.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Tuple
22

33
import torch
44
from torch_sparse.tensor import SparseTensor
@@ -22,4 +22,21 @@ def sample(src: SparseTensor, num_neighbors: int,
2222
return col[rand]
2323

2424

25+
def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
26+
replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
27+
28+
rowptr, col, _ = src.csr()
29+
rowcount = src.storage.rowcount()
30+
31+
rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
32+
rowptr, col, rowcount, subset, num_neighbors, replace)
33+
34+
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=e_id,
35+
sparse_sizes=(subset.size(0), n_id.size(0)),
36+
is_sorted=True)
37+
38+
return out, n_id
39+
40+
2541
SparseTensor.sample = sample
42+
SparseTensor.sample_adj = sample_adj

0 commit comments

Comments
 (0)