Skip to content

Commit 522de76

Browse files
committed
pytorch 1.0.0 update
1 parent e58d83e commit 522de76

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

cpu/dim_apply.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/torch.h>
3+
#include <torch/extension.h>
44

55
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
66
[&] { \
@@ -17,7 +17,7 @@
1717
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
1818
\
1919
auto dims = TENSOR1.dim(); \
20-
auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
20+
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
2121
auto counter = zeros.data<int64_t>(); \
2222
bool has_finished = false; \
2323
\
@@ -76,7 +76,7 @@
7676
auto TENSOR4##_stride = TENSOR4.stride(DIM); \
7777
\
7878
auto dims = TENSOR1.dim(); \
79-
auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
79+
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
8080
auto counter = zeros.data<int64_t>(); \
8181
bool has_finished = false; \
8282
\

cpu/scatter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include <torch/torch.h>
1+
#include <torch/extension.h>
22

33
#include "dim_apply.h"
44

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
1616
]
1717

18-
__version__ = '1.0.4'
18+
__version__ = '1.0.5'
1919
url = 'https://github.com/rusty1s/pytorch_scatter'
2020

2121
install_requires = []

torch_scatter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .max import scatter_max
88
from .min import scatter_min
99

10-
__version__ = '1.0.4'
10+
__version__ = '1.0.5'
1111

1212
__all__ = [
1313
'scatter_add',

0 commit comments

Comments
 (0)