Skip to content

Commit 468aea5

Browse files
authored
Merge pull request #68 from james77777778/master
Fixed the bug from metis if num_parts == 1
2 parents 6884191 + eb8c2ec commit 468aea5

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

test/test_metis.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,8 @@ def test_metis(device):
3030
weighted=False)
3131
assert partptr.numel() == 3
3232
assert perm.numel() == 6
33+
34+
_, partptr, perm = mat.partition(num_parts=1, recursive=False,
35+
weighted=True)
36+
assert partptr.numel() == 2
37+
assert perm.numel() == 6

torch_sparse/metis.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
2222
def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
2323
weighted=False
2424
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
25+
26+
assert num_parts >= 1
27+
if num_parts == 1:
28+
partptr = torch.tensor([0, src.size(0)], device=src.device())
29+
perm = torch.arange(src.size(0), device=src.device())
30+
return src, partptr, perm
31+
2532
rowptr, col, value = src.csr()
2633
rowptr, col = rowptr.cpu(), col.cpu()
2734

0 commit comments

Comments
 (0)