11#include < ATen/ATen.h>
2-
32#include < cusparse.h>
43
4+ #include " compat.cuh"
5+
// Threads per block used as the second launch parameter of the kernels in
// this file.
#define THREADS 1024
// Number of blocks of THREADS threads needed to cover N elements, i.e. the
// ceiling of N / THREADS. Both the argument and the whole expansion are
// parenthesized so the macro behaves like a function call inside larger
// expressions (e.g. `x / BLOCKS(n)` or `BLOCKS(a << b)`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
78
@@ -51,18 +52,18 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
5152
5253 // Convert A to CSR format.
5354 auto row_ptrA = at::empty (m + 1 , indexA.options ());
54- cusparseXcoo2csr (cusparse_handle, indexA[0 ].data <int >(), nnzA, k,
55- row_ptrA.data <int >(), CUSPARSE_INDEX_BASE_ZERO);
55+ cusparseXcoo2csr (cusparse_handle, indexA[0 ].DATA_PTR <int >(), nnzA, k,
56+ row_ptrA.DATA_PTR <int >(), CUSPARSE_INDEX_BASE_ZERO);
5657 auto colA = indexA[1 ];
57- cudaMemcpy (row_ptrA.data <int >() + m, &nnzA, sizeof (int ),
58+ cudaMemcpy (row_ptrA.DATA_PTR <int >() + m, &nnzA, sizeof (int ),
5859 cudaMemcpyHostToDevice);
5960
6061 // Convert B to CSR format.
6162 auto row_ptrB = at::empty (k + 1 , indexB.options ());
62- cusparseXcoo2csr (cusparse_handle, indexB[0 ].data <int >(), nnzB, k,
63- row_ptrB.data <int >(), CUSPARSE_INDEX_BASE_ZERO);
63+ cusparseXcoo2csr (cusparse_handle, indexB[0 ].DATA_PTR <int >(), nnzB, k,
64+ row_ptrB.DATA_PTR <int >(), CUSPARSE_INDEX_BASE_ZERO);
6465 auto colB = indexB[1 ];
65- cudaMemcpy (row_ptrB.data <int >() + k, &nnzB, sizeof (int ),
66+ cudaMemcpy (row_ptrB.DATA_PTR <int >() + k, &nnzB, sizeof (int ),
6667 cudaMemcpyHostToDevice);
6768
6869 cusparseMatDescr_t descr = 0 ;
@@ -74,22 +75,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
7475 auto row_ptrC = at::empty (m + 1 , indexB.options ());
7576 cusparseXcsrgemmNnz (cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
7677 CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
77- row_ptrA.data <int >(), colA.data <int >(), descr, nnzB ,
78- row_ptrB.data <int >(), colB.data <int >(), descr ,
79- row_ptrC.data <int >(), &nnzC);
78+ row_ptrA.DATA_PTR <int >(), colA.DATA_PTR <int >(), descr,
79+ nnzB, row_ptrB.DATA_PTR <int >(), colB.DATA_PTR <int >(),
80+ descr, row_ptrC.DATA_PTR <int >(), &nnzC);
8081 auto colC = at::empty (nnzC, indexA.options ());
8182 auto valueC = at::empty (nnzC, valueA.options ());
8283
8384 CSRGEMM (valueC.scalar_type (), cusparse_handle,
8485 CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m,
85- n, k, descr, nnzA, valueA.data <scalar_t >(), row_ptrA.data <int >(),
86- colA.data <int >(), descr, nnzB, valueB.data <scalar_t >(),
87- row_ptrB.data <int >(), colB.data <int >(), descr,
88- valueC.data <scalar_t >(), row_ptrC.data <int >(), colC.data <int >());
86+ n, k, descr, nnzA, valueA.DATA_PTR <scalar_t >(),
87+ row_ptrA.DATA_PTR <int >(), colA.DATA_PTR <int >(), descr, nnzB,
88+ valueB.DATA_PTR <scalar_t >(), row_ptrB.DATA_PTR <int >(),
89+ colB.DATA_PTR <int >(), descr, valueC.DATA_PTR <scalar_t >(),
90+ row_ptrC.DATA_PTR <int >(), colC.DATA_PTR <int >());
8991
9092 auto rowC = at::empty (nnzC, indexA.options ());
91- cusparseXcsr2coo (cusparse_handle, row_ptrC.data <int >(), nnzC, m,
92- rowC.data <int >(), CUSPARSE_INDEX_BASE_ZERO);
93+ cusparseXcsr2coo (cusparse_handle, row_ptrC.DATA_PTR <int >(), nnzC, m,
94+ rowC.DATA_PTR <int >(), CUSPARSE_INDEX_BASE_ZERO);
9395
9496 auto indexC = at::stack ({rowC, colC}, 0 ).toType (at::kLong );
9597
@@ -154,9 +156,10 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
154156
155157 AT_DISPATCH_FLOATING_TYPES (valueA.scalar_type (), " spspmm_bw" , [&] {
156158 spspmm_bw_kernel<scalar_t ><<<BLOCKS(value.numel()), THREADS>>> (
157- index.data <int64_t >(), value.data <scalar_t >(), rowA.data <int64_t >(),
158- colA.data <int64_t >(), valueA.data <scalar_t >(), rowB.data <int64_t >(),
159- colB.data <int64_t >(), valueB.data <scalar_t >(), value.numel ());
159+ index.DATA_PTR <int64_t >(), value.DATA_PTR <scalar_t >(),
160+ rowA.DATA_PTR <int64_t >(), colA.DATA_PTR <int64_t >(),
161+ valueA.DATA_PTR <scalar_t >(), rowB.DATA_PTR <int64_t >(),
162+ colB.DATA_PTR <int64_t >(), valueB.DATA_PTR <scalar_t >(), value.numel ());
160163 });
161164
162165 return value;
0 commit comments