Skip to content

Commit eafcfe0

Browse files
authored
Fix tensor creation (#222)
* adjust tensor creation * update
1 parent 7a6c9ab commit eafcfe0

File tree

11 files changed

+23
-23
lines changed

11 files changed

+23
-23
lines changed

csrc/cpu/diag_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M,
1313
auto row_data = row.data_ptr<int64_t>();
1414
auto col_data = col.data_ptr<int64_t>();
1515

16-
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
16+
auto mask = torch::zeros({E + num_diag}, row.options().dtype(torch::kBool));
1717
auto mask_data = mask.data_ptr<bool>();
1818

1919
int64_t r, c;

csrc/cpu/metis_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
4444
vwgt = optional_node_weight.value().data_ptr<int64_t>();
4545

4646
int64_t objval = -1;
47-
auto part = torch::empty(nvtxs, rowptr.options());
47+
auto part = torch::empty({nvtxs}, rowptr.options());
4848
auto part_data = part.data_ptr<int64_t>();
4949

5050
if (recursive) {
@@ -99,7 +99,7 @@ mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
9999

100100
mtmetis_pid_type nparts = num_parts;
101101
mtmetis_wgt_type objval = -1;
102-
auto part = torch::empty(nvtxs, rowptr.options());
102+
auto part = torch::empty({nvtxs}, rowptr.options());
103103
mtmetis_pid_type *part_data = (mtmetis_pid_type *)part.data_ptr<int64_t>();
104104

105105
double *opts = mtmetis_init_options();

csrc/cpu/relabel_cpu.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
6464
std::unordered_map<int64_t, int64_t> n_id_map;
6565
std::unordered_map<int64_t, int64_t>::iterator it;
6666

67-
auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
67+
auto out_rowptr = torch::empty({idx.numel() + 1}, rowptr.options());
6868
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
6969

7070
out_rowptr_data[0] = 0;
@@ -76,12 +76,12 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
7676
out_rowptr_data[i + 1] = offset;
7777
}
7878

79-
auto out_col = torch::empty(offset, col.options());
79+
auto out_col = torch::empty({offset}, col.options());
8080
auto out_col_data = out_col.data_ptr<int64_t>();
8181

8282
torch::optional<torch::Tensor> out_value = torch::nullopt;
8383
if (optional_value.has_value()) {
84-
out_value = torch::empty(offset, optional_value.value().options());
84+
out_value = torch::empty({offset}, optional_value.value().options());
8585

8686
AT_DISPATCH_ALL_TYPES(optional_value.value().scalar_type(), "relabel", [&] {
8787
auto value_data = optional_value.value().data_ptr<scalar_t>();

csrc/cpu/sample_cpu.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
1919
auto col_data = col.data_ptr<int64_t>();
2020
auto idx_data = idx.data_ptr<int64_t>();
2121

22-
auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
22+
auto out_rowptr = torch::empty({idx.numel() + 1}, rowptr.options());
2323
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
2424
out_rowptr_data[0] = 0;
2525

@@ -117,9 +117,9 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
117117
auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone();
118118

119119
int64_t E = out_rowptr_data[idx.numel()];
120-
auto out_col = torch::empty(E, col.options());
120+
auto out_col = torch::empty({E}, col.options());
121121
auto out_col_data = out_col.data_ptr<int64_t>();
122-
auto out_e_id = torch::empty(E, col.options());
122+
auto out_e_id = torch::empty({E}, col.options());
123123
auto out_e_id_data = out_e_id.data_ptr<int64_t>();
124124

125125
i = 0;

csrc/cpu/spmm_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
118118
auto K = mat.size(-1);
119119
auto B = mat.numel() / (N * K);
120120

121-
auto out = torch::zeros(row.numel(), grad.options());
121+
auto out = torch::zeros({row.numel()}, grad.options());
122122

123123
auto row_data = row.data_ptr<int64_t>();
124124
auto rowptr_data = rowptr.data_ptr<int64_t>();

csrc/cpu/spspmm_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
3333

3434
if (!optional_valueA.has_value() && optional_valueB.has_value())
3535
optional_valueA =
36-
torch::ones(colA.numel(), optional_valueB.value().options());
36+
torch::ones({colA.numel()}, optional_valueB.value().options());
3737

3838
if (!optional_valueB.has_value() && optional_valueA.has_value())
3939
optional_valueB =
40-
torch::ones(colB.numel(), optional_valueA.value().options());
40+
torch::ones({colB.numel()}, optional_valueA.value().options());
4141

4242
auto scalar_type = torch::ScalarType::Float;
4343
if (optional_valueA.has_value())

csrc/cpu/utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
6161
return torch::multinomial(weight.value(), num_samples, replace);
6262

6363
if (replace) {
64-
const auto out = torch::empty(num_samples, at::kLong);
64+
const auto out = torch::empty({num_samples}, at::kLong);
6565
auto *out_data = out.data_ptr<int64_t>();
6666
for (int64_t i = 0; i < num_samples; i++) {
6767
out_data[i] = uniform_randint(population);
@@ -72,7 +72,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
7272
// Sample without replacement via Robert Floyd algorithm:
7373
// https://www.nowherenearithaca.com/2013/05/
7474
// robert-floyds-tiny-and-beautiful.html
75-
const auto out = torch::empty(num_samples, at::kLong);
75+
const auto out = torch::empty({num_samples}, at::kLong);
7676
auto *out_data = out.data_ptr<int64_t>();
7777
std::unordered_set<int64_t> samples;
7878
for (int64_t i = population - num_samples; i < population; i++) {

csrc/cuda/convert_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
2727
CHECK_CUDA(ind);
2828
cudaSetDevice(ind.get_device());
2929

30-
auto out = torch::empty(M + 1, ind.options());
30+
auto out = torch::empty({M + 1}, ind.options());
3131

3232
if (ind.numel() == 0)
3333
return out.zero_();
@@ -57,7 +57,7 @@ torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
5757
CHECK_CUDA(ptr);
5858
cudaSetDevice(ptr.get_device());
5959

60-
auto out = torch::empty(E, ptr.options());
60+
auto out = torch::empty({E}, ptr.options());
6161
auto ptr_data = ptr.data_ptr<int64_t>();
6262
auto out_data = out.data_ptr<int64_t>();
6363
auto stream = at::cuda::getCurrentCUDAStream();

csrc/cuda/diag_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
5151
auto row_data = row.data_ptr<int64_t>();
5252
auto col_data = col.data_ptr<int64_t>();
5353

54-
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
54+
auto mask = torch::zeros({E + num_diag}, row.options().dtype(torch::kBool));
5555
auto mask_data = mask.data_ptr<bool>();
5656

5757
if (E == 0)

csrc/cuda/spmm_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
213213
auto B = mat.numel() / (N * K);
214214
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
215215

216-
auto out = torch::zeros(row.numel(), grad.options());
216+
auto out = torch::zeros({row.numel()}, grad.options());
217217

218218
auto row_data = row.data_ptr<int64_t>();
219219
auto rowptr_data = rowptr.data_ptr<int64_t>();

0 commit comments

Comments
 (0)