Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions csrc/cpu/metis_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
#include "utils.h"

torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
std::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr);
Expand Down Expand Up @@ -66,8 +66,8 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
// --partitions64bit
torch::Tensor
mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
std::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive, int64_t num_workers) {
#ifdef WITH_MTMETIS
CHECK_CPU(rowptr);
Expand Down
8 changes: 4 additions & 4 deletions csrc/cpu/metis_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#include "../extensions.h"

torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
std::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive);

torch::Tensor
mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
std::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive, int64_t num_workers);
6 changes: 3 additions & 3 deletions csrc/cpu/relabel_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
return std::make_tuple(out_col, out_idx);
}

std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
std::tuple<torch::Tensor, torch::Tensor, std::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite) {

CHECK_CPU(rowptr);
Expand Down Expand Up @@ -79,7 +79,7 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
auto out_col = torch::empty({offset}, col.options());
auto out_col_data = out_col.data_ptr<int64_t>();

torch::optional<torch::Tensor> out_value = torch::nullopt;
std::optional<torch::Tensor> out_value = std::nullopt;
if (optional_value.has_value()) {
out_value = torch::empty({offset}, optional_value.value().options());

Expand Down
4 changes: 2 additions & 2 deletions csrc/cpu/relabel_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
torch::Tensor idx);

std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
std::tuple<torch::Tensor, torch::Tensor, std::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite);
6 changes: 3 additions & 3 deletions csrc/cpu/spmm_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#include "reducer.h"
#include "utils.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
Expand All @@ -29,7 +29,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = torch::empty(sizes, mat.options());

torch::optional<torch::Tensor> arg_out = torch::nullopt;
std::optional<torch::Tensor> arg_out = std::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, col.numel(), rowptr.options());
Expand Down
4 changes: 2 additions & 2 deletions csrc/cpu/spmm_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

#include "../extensions.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce);

torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ inline int64_t uniform_randint(int64_t high) {

inline torch::Tensor
choice(int64_t population, int64_t num_samples, bool replace = false,
torch::optional<torch::Tensor> weight = torch::nullopt) {
std::optional<torch::Tensor> weight = std::nullopt) {

if (population == 0 || num_samples == 0)
return torch::empty({0}, at::kLong);
Expand Down
6 changes: 3 additions & 3 deletions csrc/cuda/spmm_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
}
}

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {

CHECK_CUDA(rowptr);
Expand All @@ -115,7 +115,7 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = torch::empty(sizes, mat.options());

torch::optional<torch::Tensor> arg_out = torch::nullopt;
std::optional<torch::Tensor> arg_out = std::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, col.numel(), rowptr.options());
Expand Down
4 changes: 2 additions & 2 deletions csrc/cuda/spmm_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

#include "../extensions.h"

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce);

torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
Expand Down
12 changes: 6 additions & 6 deletions csrc/metis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
#endif

SPARSE_API torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
Expand All @@ -25,14 +25,14 @@ SPARSE_API torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_cpu(rowptr, col, optional_value, torch::nullopt, num_parts,
return partition_cpu(rowptr, col, optional_value, std::nullopt, num_parts,
recursive);
}
}

SPARSE_API torch::Tensor partition2(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
std::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
Expand All @@ -47,8 +47,8 @@ SPARSE_API torch::Tensor partition2(torch::Tensor rowptr, torch::Tensor col,
}

SPARSE_API torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
std::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive,
int64_t num_workers) {
if (rowptr.device().is_cuda()) {
Expand Down
4 changes: 2 additions & 2 deletions csrc/relabel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ SPARSE_API std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
}
}

SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
SPARSE_API std::tuple<torch::Tensor, torch::Tensor, std::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
Expand Down
36 changes: 18 additions & 18 deletions csrc/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,28 @@ SPARSE_API torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E);

SPARSE_API torch::Tensor
partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, int64_t num_parts,
std::optional<torch::Tensor> optional_value, int64_t num_parts,
bool recursive);

SPARSE_API torch::Tensor
partition2(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
std::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive);

SPARSE_API torch::Tensor
mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
std::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive, int64_t num_workers);

SPARSE_API std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
torch::Tensor idx);

SPARSE_API std::tuple<torch::Tensor, torch::Tensor,
torch::optional<torch::Tensor>, torch::Tensor>
std::optional<torch::Tensor>, torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
std::optional<torch::Tensor> optional_value,
torch::Tensor idx, bool bipartite);

SPARSE_API torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
Expand All @@ -52,25 +52,25 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
int64_t num_neighbors, bool replace);

SPARSE_API torch::Tensor spmm_sum(torch::optional<torch::Tensor> opt_row,
SPARSE_API torch::Tensor spmm_sum(std::optional<torch::Tensor> opt_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value,
torch::optional<torch::Tensor> opt_colptr,
torch::optional<torch::Tensor> opt_csr2csc,
std::optional<torch::Tensor> opt_value,
std::optional<torch::Tensor> opt_colptr,
std::optional<torch::Tensor> opt_csr2csc,
torch::Tensor mat);

SPARSE_API torch::Tensor spmm_mean(torch::optional<torch::Tensor> opt_row,
SPARSE_API torch::Tensor spmm_mean(std::optional<torch::Tensor> opt_row,
torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value,
torch::optional<torch::Tensor> opt_rowcount,
torch::optional<torch::Tensor> opt_colptr,
torch::optional<torch::Tensor> opt_csr2csc,
std::optional<torch::Tensor> opt_value,
std::optional<torch::Tensor> opt_rowcount,
std::optional<torch::Tensor> opt_colptr,
std::optional<torch::Tensor> opt_csr2csc,
torch::Tensor mat);

SPARSE_API std::tuple<torch::Tensor, torch::Tensor>
spmm_min(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat);
std::optional<torch::Tensor> opt_value, torch::Tensor mat);

SPARSE_API std::tuple<torch::Tensor, torch::Tensor>
spmm_max(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> opt_value, torch::Tensor mat);
std::optional<torch::Tensor> opt_value, torch::Tensor mat);
Loading
Loading