Skip to content

Commit 45a4d98

Browse files
committed
fix the bug from metis if num_parts == 1
1 parent 6884191 commit 45a4d98

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

test/test_metis.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,8 @@ def test_metis(device):
3030
weighted=False)
3131
assert partptr.numel() == 3
3232
assert perm.numel() == 6
33+
34+
_, partptr, perm = mat.partition(num_parts=1, recursive=False,
35+
weighted=True)
36+
assert partptr.numel() == 2
37+
assert perm.numel() == 6

torch_sparse/metis.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,13 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
3333
else:
3434
value = None
3535

36-
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
37-
recursive)
36+
if num_parts > 1:
37+
cluster = torch.ops.torch_sparse.partition(rowptr, col, value,
38+
num_parts, recursive)
39+
elif num_parts == 1:
40+
cluster = torch.zeros((src.size(0)), dtype=torch.long)
41+
else:
42+
raise ValueError
3843
cluster = cluster.to(src.device())
3944

4045
cluster, perm = cluster.sort()

0 commit comments

Comments
 (0)