Skip to content

Commit 2bcc13e

Browse files
committed
added header file
1 parent a9d4d46 commit 2bcc13e

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

csrc/sparse.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
5+
int64_t cuda_version();
6+
7+
torch::Tensor ind2ptr(torch::Tensor ind, int64_t M);
8+
torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E);
9+
10+
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
11+
torch::optional<torch::Tensor> optional_value,
12+
int64_t num_parts, bool recursive);
13+
14+
std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
15+
torch::Tensor idx);
16+
17+
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
18+
torch::Tensor start, int64_t walk_length);
19+
20+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
21+
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
22+
torch::Tensor col);
23+
24+
sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
25+
torch::Tensor idx, int64_t num_neighbors, bool replace);
26+
27+
torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
28+
torch::Tensor rowptr, torch::Tensor col,
29+
torch::optional<torch::Tensor> opt_value,
30+
torch::optional<torch::Tensor> opt_colptr,
31+
torch::optional<torch::Tensor> opt_csr2csc,
32+
torch::Tensor mat);
33+
34+
torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
35+
torch::Tensor rowptr, torch::Tensor col,
36+
torch::optional<torch::Tensor> opt_value,
37+
torch::optional<torch::Tensor> opt_rowcount,
38+
torch::optional<torch::Tensor> opt_colptr,
39+
torch::optional<torch::Tensor> opt_csr2csc,
40+
torch::Tensor mat);
41+
42+
std::tuple<torch::Tensor, torch::Tensor>
43+
spmm_min(torch::Tensor rowptr, torch::Tensor col,
44+
torch::optional<torch::Tensor> opt_value, torch::Tensor mat);
45+
46+
std::tuple<torch::Tensor, torch::Tensor>
47+
spmm_max(torch::Tensor rowptr, torch::Tensor col,
48+
torch::optional<torch::Tensor> opt_value, torch::Tensor mat);
49+
50+
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
51+
spspmm_sum(torch::Tensor rowptrA, torch::Tensor colA,
52+
torch::optional<torch::Tensor> optional_valueA,
53+
torch::Tensor rowptrB, torch::Tensor colB,
54+
torch::optional<torch::Tensor> optional_valueB, int64_t K);

0 commit comments

Comments
 (0)