Skip to content

Commit de52883

Browse files
committed
pytorch 1.1.0 update
1 parent 9732a51 commit de52883

File tree

7 files changed

+27
-25
lines changed

7 files changed

+27
-25
lines changed

.travis.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ before_install:
1717
- export CC="gcc-4.9"
1818
- export CXX="g++-4.9"
1919
install:
20-
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl; fi
21-
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl; fi
22-
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl; fi
20+
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl; fi
21+
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp35-cp35m-linux_x86_64.whl; fi
22+
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl; fi
2323
- pip install pycodestyle
2424
- pip install flake8
2525
- pip install codecov

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Note that only `value` comes with autograd support, as `index` is discrete and t
2828

2929
## Installation
3030

31-
Ensure that at least PyTorch 1.0.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
31+
Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
3232

3333
```
3434
$ python -c "import torch; print(torch.__version__)"

cpu/spspmm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
3131
int64_t *rowB_data = rowB.data<int64_t>();
3232
int64_t *colB_data = colB.data<int64_t>();
3333

34-
AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
34+
AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
3535
scalar_t *value_data = value.data<scalar_t>();
3636
scalar_t *valueA_data = valueA.data<scalar_t>();
3737
scalar_t *valueB_data = valueB.data<scalar_t>();

cuda/spspmm_kernel.cu

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
#define CSRGEMM(TYPE, ...) \
99
[&] { \
10-
const at::Type &the_type = TYPE; \
11-
switch (the_type.scalarType()) { \
10+
const auto &the_type = TYPE; \
11+
(void)the_type; \
12+
at::ScalarType _st = ::detail::scalar_type(TYPE); \
13+
switch (_st) { \
1214
case at::ScalarType::Float: { \
1315
using scalar_t = float; \
1416
return cusparseScsrgemm(__VA_ARGS__); \
@@ -18,7 +20,7 @@
1820
return cusparseDcsrgemm(__VA_ARGS__); \
1921
} \
2022
default: \
21-
AT_ERROR("Not implemented for '%s'", the_type.toString()); \
23+
AT_ERROR("Not implemented for '", toString(_st), "'"); \
2224
} \
2325
}()
2426

@@ -48,15 +50,15 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
4850
indexB = indexB.toType(at::kInt);
4951

5052
// Convert A to CSR format.
51-
auto row_ptrA = at::empty(m + 1, indexA.type());
53+
auto row_ptrA = at::empty(m + 1, indexA.options());
5254
cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k,
5355
row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
5456
auto colA = indexA[1];
5557
cudaMemcpy(row_ptrA.data<int>() + m, &nnzA, sizeof(int),
5658
cudaMemcpyHostToDevice);
5759

5860
// Convert B to CSR format.
59-
auto row_ptrB = at::empty(k + 1, indexB.type());
61+
auto row_ptrB = at::empty(k + 1, indexB.options());
6062
cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
6163
row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
6264
auto colB = indexB[1];
@@ -69,23 +71,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
6971
cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
7072

7173
int nnzC;
72-
auto row_ptrC = at::empty(m + 1, indexB.type());
74+
auto row_ptrC = at::empty(m + 1, indexB.options());
7375
cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
7476
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
7577
row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
7678
row_ptrB.data<int>(), colB.data<int>(), descr,
7779
row_ptrC.data<int>(), &nnzC);
78-
auto colC = at::empty(nnzC, indexA.type());
79-
auto valueC = at::empty(nnzC, valueA.type());
80+
auto colC = at::empty(nnzC, indexA.options());
81+
auto valueC = at::empty(nnzC, valueA.options());
8082

81-
CSRGEMM(valueC.type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
82-
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
83-
valueA.data<scalar_t>(), row_ptrA.data<int>(), colA.data<int>(),
84-
descr, nnzB, valueB.data<scalar_t>(), row_ptrB.data<int>(),
85-
colB.data<int>(), descr, valueC.data<scalar_t>(),
86-
row_ptrC.data<int>(), colC.data<int>());
83+
CSRGEMM(valueC.scalar_type(), cusparse_handle,
84+
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m,
85+
n, k, descr, nnzA, valueA.data<scalar_t>(), row_ptrA.data<int>(),
86+
colA.data<int>(), descr, nnzB, valueB.data<scalar_t>(),
87+
row_ptrB.data<int>(), colB.data<int>(), descr,
88+
valueC.data<scalar_t>(), row_ptrC.data<int>(), colC.data<int>());
8789

88-
auto rowC = at::empty(nnzC, indexA.type());
90+
auto rowC = at::empty(nnzC, indexA.options());
8991
cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
9092
rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
9193

@@ -150,7 +152,7 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
150152
at::Tensor rowB, colB;
151153
std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
152154

153-
AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
155+
AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
154156
spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
155157
index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(),
156158
colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(),

cuda/unique_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
2020
at::Tensor perm;
2121
std::tie(src, perm) = src.sort();
2222

23-
auto mask = at::zeros(src.numel(), src.type().toScalarType(at::kByte));
24-
AT_DISPATCH_ALL_TYPES(src.type(), "grid_cuda_kernel", [&] {
23+
auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte));
24+
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
2525
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
2626
src.data<scalar_t>(), mask.data<uint8_t>(), src.numel());
2727
});

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
['cuda/unique.cpp', 'cuda/unique_kernel.cu']),
2222
]
2323

24-
__version__ = '0.3.0'
24+
__version__ = '0.4.0'
2525
url = 'https://github.com/rusty1s/pytorch_sparse'
2626

2727
install_requires = ['scipy']

torch_sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .spmm import spmm
66
from .spspmm import spspmm
77

8-
__version__ = '0.3.0'
8+
__version__ = '0.4.0'
99

1010
__all__ = [
1111
'__version__',

0 commit comments

Comments (0)