Skip to content

Commit eb8c2ec

Browse files
committed
cleanup
1 parent 45a4d98 commit eb8c2ec

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

torch_sparse/metis.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
2222
def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
2323
weighted=False
2424
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
25+
26+
assert num_parts >= 1
27+
if num_parts == 1:
28+
partptr = torch.tensor([0, src.size(0)], device=src.device())
29+
perm = torch.arange(src.size(0), device=src.device())
30+
return src, partptr, perm
31+
2532
rowptr, col, value = src.csr()
2633
rowptr, col = rowptr.cpu(), col.cpu()
2734

@@ -33,13 +40,8 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
3340
else:
3441
value = None
3542

36-
if num_parts > 1:
37-
cluster = torch.ops.torch_sparse.partition(rowptr, col, value,
38-
num_parts, recursive)
39-
elif num_parts == 1:
40-
cluster = torch.zeros((src.size(0)), dtype=torch.long)
41-
else:
42-
raise ValueError
43+
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
44+
recursive)
4345
cluster = cluster.to(src.device())
4446

4547
cluster, perm = cluster.sort()

0 commit comments

Comments
 (0)