Skip to content

Commit d3a94a2

Browse files
committed
add relabeling
1 parent ed3de95 commit d3a94a2

File tree

4 files changed

+74
-1
lines changed

4 files changed

+74
-1
lines changed

csrc/cpu/relabel_cpu.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "relabel_cpu.h"
2+
3+
#include "utils.h"
4+
5+
std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
6+
torch::Tensor idx) {
7+
8+
CHECK_CPU(col);
9+
CHECK_CPU(idx);
10+
CHECK_INPUT(idx.dim() == 1);
11+
12+
auto col_data = col.data_ptr<int64_t>();
13+
auto idx_data = idx.data_ptr<int64_t>();
14+
15+
std::vector<int64_t> cols;
16+
std::vector<int64_t> n_ids;
17+
std::unordered_map<int64_t, int64_t> n_id_map;
18+
19+
int64_t i;
20+
for (int64_t n = 0; n < idx.size(0); n++) {
21+
i = idx_data[n];
22+
n_id_map[i] = n;
23+
n_ids.push_back(i);
24+
}
25+
26+
int64_t c;
27+
for (int64_t e = 0; e < col.size(0); e++) {
28+
c = col_data[e];
29+
30+
if (n_id_map.count(c) == 0) {
31+
n_id_map[c] = n_ids.size();
32+
n_ids.push_back(c);
33+
}
34+
35+
cols.push_back(n_id_map[c]);
36+
}
37+
38+
int64_t n_len = n_ids.size(), e_len = cols.size();
39+
auto out_col = torch::from_blob(cols.data(), {e_len}, col.options()).clone();
40+
auto out_idx = torch::from_blob(n_ids.data(), {n_len}, col.options()).clone();
41+
42+
return std::make_tuple(out_col, out_idx);
43+
}

csrc/cpu/relabel_cpu.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
5+
std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
6+
torch::Tensor idx);

csrc/relabel.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include <Python.h>
2+
#include <torch/script.h>
3+
4+
#include "cpu/relabel_cpu.h"
5+
6+
#ifdef _WIN32
7+
PyMODINIT_FUNC PyInit__relabel(void) { return NULL; }
8+
#endif
9+
10+
std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
11+
torch::Tensor idx) {
12+
if (col.device().is_cuda()) {
13+
#ifdef WITH_CUDA
14+
AT_ERROR("No CUDA version supported");
15+
#else
16+
AT_ERROR("Not compiled with CUDA support");
17+
#endif
18+
} else {
19+
return relabel_cpu(col, idx);
20+
}
21+
}
22+
23+
static auto registry =
24+
torch::RegisterOperators().op("torch_sparse::relabel", &relabel);

torch_sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
for library in [
99
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
10-
'_saint', '_sample'
10+
'_saint', '_sample', '_relabel'
1111
]:
1212
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
1313
library, [osp.dirname(__file__)]).origin)

0 commit comments

Comments
 (0)