77
88#define CSRGEMM (TYPE, ...) \
99 [&] { \
10- const at::Type &the_type = TYPE; \
11- switch (the_type.scalarType ()) { \
10+ const auto &the_type = TYPE; \
11+ (void )the_type; \
12+ at::ScalarType _st = ::detail::scalar_type (TYPE); \
13+ switch (_st) { \
1214 case at::ScalarType::Float: { \
1315 using scalar_t = float ; \
1416 return cusparseScsrgemm (__VA_ARGS__); \
1820 return cusparseDcsrgemm (__VA_ARGS__); \
1921 } \
2022 default : \
21- AT_ERROR (" Not implemented for '%s' " , the_type. toString ()); \
23+ AT_ERROR (" Not implemented for '" , toString (_st), " ' " ); \
2224 } \
2325 }()
2426
@@ -48,15 +50,15 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
4850 indexB = indexB.toType (at::kInt );
4951
5052 // Convert A to CSR format.
51- auto row_ptrA = at::empty (m + 1 , indexA.type ());
53+ auto row_ptrA = at::empty (m + 1 , indexA.options ());
5254 cusparseXcoo2csr (cusparse_handle, indexA[0 ].data <int >(), nnzA, k,
5355 row_ptrA.data <int >(), CUSPARSE_INDEX_BASE_ZERO);
5456 auto colA = indexA[1 ];
5557 cudaMemcpy (row_ptrA.data <int >() + m, &nnzA, sizeof (int ),
5658 cudaMemcpyHostToDevice);
5759
5860 // Convert B to CSR format.
59- auto row_ptrB = at::empty (k + 1 , indexB.type ());
61+ auto row_ptrB = at::empty (k + 1 , indexB.options ());
6062 cusparseXcoo2csr (cusparse_handle, indexB[0 ].data <int >(), nnzB, k,
6163 row_ptrB.data <int >(), CUSPARSE_INDEX_BASE_ZERO);
6264 auto colB = indexB[1 ];
@@ -69,23 +71,23 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
6971 cusparseSetMatIndexBase (descr, CUSPARSE_INDEX_BASE_ZERO);
7072
7173 int nnzC;
72- auto row_ptrC = at::empty (m + 1 , indexB.type ());
74+ auto row_ptrC = at::empty (m + 1 , indexB.options ());
7375 cusparseXcsrgemmNnz (cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
7476 CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
7577 row_ptrA.data <int >(), colA.data <int >(), descr, nnzB,
7678 row_ptrB.data <int >(), colB.data <int >(), descr,
7779 row_ptrC.data <int >(), &nnzC);
78- auto colC = at::empty (nnzC, indexA.type ());
79- auto valueC = at::empty (nnzC, valueA.type ());
80+ auto colC = at::empty (nnzC, indexA.options ());
81+ auto valueC = at::empty (nnzC, valueA.options ());
8082
81- CSRGEMM (valueC.type (), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE ,
82- CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA ,
83- valueA.data <scalar_t >(), row_ptrA. data < int >(), colA .data <int >(),
84- descr, nnzB, valueB .data <scalar_t >(), row_ptrB .data <int >(),
85- colB .data <int >(), descr, valueC .data <scalar_t >(),
86- row_ptrC.data <int >(), colC.data <int >());
83+ CSRGEMM (valueC.scalar_type (), cusparse_handle,
84+ CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, m ,
85+ n, k, descr, nnzA, valueA.data <scalar_t >(), row_ptrA.data <int >(),
86+ colA .data <int >(), descr, nnzB, valueB .data <scalar_t >(),
87+ row_ptrB .data <int >(), colB .data <int >(), descr ,
88+ valueC. data < scalar_t >(), row_ptrC.data <int >(), colC.data <int >());
8789
88- auto rowC = at::empty (nnzC, indexA.type ());
90+ auto rowC = at::empty (nnzC, indexA.options ());
8991 cusparseXcsr2coo (cusparse_handle, row_ptrC.data <int >(), nnzC, m,
9092 rowC.data <int >(), CUSPARSE_INDEX_BASE_ZERO);
9193
@@ -150,7 +152,7 @@ at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
150152 at::Tensor rowB, colB;
151153 std::tie (rowB, colB) = to_csr (indexB[0 ], indexB[1 ], rowB_max);
152154
153- AT_DISPATCH_FLOATING_TYPES (valueA.type (), " spspmm_bw" , [&] {
155+ AT_DISPATCH_FLOATING_TYPES (valueA.scalar_type (), " spspmm_bw" , [&] {
154156 spspmm_bw_kernel<scalar_t ><<<BLOCKS(value.numel()), THREADS>>> (
155157 index.data <int64_t >(), value.data <scalar_t >(), rowA.data <int64_t >(),
156158 colA.data <int64_t >(), valueA.data <scalar_t >(), rowB.data <int64_t >(),
0 commit comments