Skip to content

Commit 57852a6

Browse files
committed
bandwidth implementation
1 parent 539e206 commit 57852a6

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

torch_sparse/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from .cat import cat, cat_diag # noqa
4848
from .rw import random_walk # noqa
4949
from .metis import partition # noqa
50+
from .bandwidth import reverse_cuthill_mckee # noqa
5051
from .saint import saint_subgraph # noqa
5152
from .padding import padded_index, padded_index_select # noqa
5253

@@ -90,6 +91,7 @@
9091
'cat_diag',
9192
'random_walk',
9293
'partition',
94+
'reverse_cuthill_mckee',
9395
'saint_subgraph',
9496
'padded_index',
9597
'padded_index_select',

torch_sparse/bandwidth.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import scipy.sparse as sp
2+
from typing import Tuple, Optional
3+
4+
import torch
5+
from torch_sparse.tensor import SparseTensor
6+
from torch_sparse.permute import permute
7+
8+
9+
def reverse_cuthill_mckee(src: SparseTensor,
10+
is_symmetric: Optional[bool] = None
11+
) -> Tuple[SparseTensor, torch.Tensor]:
12+
13+
if is_symmetric is None:
14+
is_symmetric = src.is_symmetric()
15+
16+
if not is_symmetric:
17+
src = src.to_symmetric()
18+
19+
sp_src = src.to_scipy(layout='csr')
20+
perm = sp.csgraph.reverse_cuthill_mckee(sp_src, symmetric_mode=True).copy()
21+
perm = torch.from_numpy(perm).to(torch.long).to(src.device())
22+
23+
out = permute(src, perm)
24+
25+
return out, perm
26+
27+
28+
SparseTensor.reverse_cuthill_mckee = reverse_cuthill_mckee

0 commit comments

Comments
 (0)