Skip to content

Commit 48eed83

Browse files
committed
backward compatibility with torch-sparse==0.6.9
1 parent d42a18a commit 48eed83

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

csrc/metis.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,28 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
1313

1414
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
1515
torch::optional<torch::Tensor> optional_value,
16-
torch::optional<torch::Tensor> optional_node_weight,
1716
int64_t num_parts, bool recursive) {
1817
if (rowptr.device().is_cuda()) {
1918
#ifdef WITH_CUDA
2019
AT_ERROR("No CUDA version supported");
2120
#else
2221
AT_ERROR("Not compiled with CUDA support");
22+
#endif
23+
} else {
24+
return partition_cpu(rowptr, col, optional_value, nullptr, num_parts,
25+
recursive);
26+
}
27+
}
28+
29+
torch::Tensor partition2(torch::Tensor rowptr, torch::Tensor col,
30+
torch::optional<torch::Tensor> optional_value,
31+
torch::optional<torch::Tensor> optional_node_weight,
32+
int64_t num_parts, bool recursive) {
33+
if (rowptr.device().is_cuda()) {
34+
#ifdef WITH_CUDA
35+
AT_ERROR("No CUDA version supported");
36+
#else
37+
AT_ERROR("Not compiled with CUDA support");
2338
#endif
2439
} else {
2540
return partition_cpu(rowptr, col, optional_value, optional_node_weight,
@@ -46,4 +61,5 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
4661

4762
static auto registry = torch::RegisterOperators()
4863
.op("torch_sparse::partition", &partition)
64+
.op("torch_sparse::partition2", &partition2)
4965
.op("torch_sparse::mt_partition", &mt_partition);

csrc/sparse.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
1111
torch::optional<torch::Tensor> optional_value,
1212
int64_t num_parts, bool recursive);
1313

14+
torch::Tensor partition2(torch::Tensor rowptr, torch::Tensor col,
15+
torch::optional<torch::Tensor> optional_value,
16+
torch::optional<torch::Tensor> optional_node_weight,
17+
int64_t num_parts, bool recursive);
18+
1419
torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
1520
torch::optional<torch::Tensor> optional_value,
1621
int64_t num_parts, bool recursive);

torch_sparse/metis.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@ def partition(
4646
node_weight = node_weight.view(-1).detach().cpu()
4747
if node_weight.is_floating_point():
4848
node_weight = weight2metis(node_weight)
49-
50-
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, node_weight,
51-
num_parts, recursive)
49+
cluster = torch.ops.torch_sparse.partition2(rowptr, col, value,
50+
node_weight, num_parts,
51+
recursive)
52+
else:
53+
cluster = torch.ops.torch_sparse.partition(rowptr, col, value,
54+
num_parts, recursive)
5255
cluster = cluster.to(src.device())
5356

5457
cluster, perm = cluster.sort()

0 commit comments

Comments
 (0)