@@ -13,13 +13,28 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
1313
1414torch::Tensor partition (torch::Tensor rowptr, torch::Tensor col,
1515 torch::optional<torch::Tensor> optional_value,
16- torch::optional<torch::Tensor> optional_node_weight,
1716 int64_t num_parts, bool recursive) {
1817 if (rowptr.device ().is_cuda ()) {
1918#ifdef WITH_CUDA
2019 AT_ERROR (" No CUDA version supported" );
2120#else
2221 AT_ERROR (" Not compiled with CUDA support" );
22+ #endif
23+ } else {
24+ return partition_cpu (rowptr, col, optional_value, nullptr , num_parts,
25+ recursive);
26+ }
27+ }
28+
29+ torch::Tensor partition2 (torch::Tensor rowptr, torch::Tensor col,
30+ torch::optional<torch::Tensor> optional_value,
31+ torch::optional<torch::Tensor> optional_node_weight,
32+ int64_t num_parts, bool recursive) {
33+ if (rowptr.device ().is_cuda ()) {
34+ #ifdef WITH_CUDA
35+ AT_ERROR (" No CUDA version supported" );
36+ #else
37+ AT_ERROR (" Not compiled with CUDA support" );
2338#endif
2439 } else {
2540 return partition_cpu (rowptr, col, optional_value, optional_node_weight,
@@ -46,4 +61,5 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
4661
4762static auto registry = torch::RegisterOperators()
4863 .op(" torch_sparse::partition" , &partition)
64+ .op(" torch_sparse::partition2" , &partition2)
4965 .op(" torch_sparse::mt_partition" , &mt_partition);
0 commit comments