@@ -48,80 +48,56 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
4848 auto rowptrB_data = rowptrB.data_ptr <int64_t >();
4949 auto colB_data = colB.data_ptr <int64_t >();
5050
51- // Pass 1: Compute CSR row pointer.
5251 auto rowptrC = torch::empty_like (rowptrA);
5352 auto rowptrC_data = rowptrC.data_ptr <int64_t >();
5453 rowptrC_data[0 ] = 0 ;
5554
56- std::vector<int64_t > mask (K, -1 );
57- int64_t nnz = 0 , row_nnz, rowA_start, rowA_end, rowB_start, rowB_end, cA, cB;
58- for (auto n = 0 ; n < rowptrA.numel () - 1 ; n++) {
59- row_nnz = 0 ;
60-
61- for (auto eA = rowptrA_data[n]; eA < rowptrA_data[n + 1 ]; eA++) {
62- cA = colA_data[eA];
63- for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1 ]; eB++) {
64- cB = colB_data[eB];
65- if (mask[cB] != n) {
66- mask[cB] = n;
67- row_nnz++;
68- }
69- }
70- }
71-
72- nnz += row_nnz;
73- rowptrC_data[n + 1 ] = nnz;
74- }
75-
76- // Pass 2: Compute CSR entries.
77- auto colC = torch::empty (nnz, rowptrC.options ());
78- auto colC_data = colC.data_ptr <int64_t >();
79-
55+ torch::Tensor colC;
8056 torch::optional<torch::Tensor> optional_valueC = torch::nullopt ;
81- if (optional_valueA.has_value ())
82- optional_valueC = torch::empty (nnz, optional_valueA.value ().options ());
8357
8458 AT_DISPATCH_ALL_TYPES (scalar_type, " spspmm" , [&] {
85- AT_DISPATCH_HAS_VALUE (optional_valueC , [&] {
86- scalar_t *valA_data = nullptr , *valB_data = nullptr , *valC_data = nullptr ;
59+ AT_DISPATCH_HAS_VALUE (optional_valueA , [&] {
60+ scalar_t *valA_data = nullptr , *valB_data = nullptr ;
8761 if (HAS_VALUE) {
8862 valA_data = optional_valueA.value ().data_ptr <scalar_t >();
8963 valB_data = optional_valueB.value ().data_ptr <scalar_t >();
90- valC_data = optional_valueC.value ().data_ptr <scalar_t >();
9164 }
92- scalar_t valA;
9365
94- rowA_start = 0 , nnz = 0 ;
95- std::vector<scalar_t > vals (K, 0 );
96- for ( auto n = 1 ; n < rowptrA. numel (); n++) {
97- rowA_end = rowptrA_data[n] ;
66+ int64_t nnz = 0 , cA, cB ;
67+ std::vector<scalar_t > tmp_vals (K, 0 );
68+ std::vector< int64_t > cols;
69+ std::vector< scalar_t > vals ;
9870
99- for (auto eA = rowA_start; eA < rowA_end; eA++) {
71+ for (auto rA = 0 ; rA < rowptrA.numel () - 1 ; rA++) {
72+ for (auto eA = rowptrA_data[rA]; eA < rowptrA_data[rA + 1 ]; eA++) {
10073 cA = colA_data[eA];
101- if (HAS_VALUE)
102- valA = valA_data[eA];
103-
104- rowB_start = rowptrB_data[cA], rowB_end = rowptrB_data[cA + 1 ];
105- for (auto eB = rowB_start; eB < rowB_end; eB++) {
74+ for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1 ]; eB++) {
10675 cB = colB_data[eB];
76+
10777 if (HAS_VALUE)
108- vals [cB] += valA * valB_data[eB];
78+ tmp_vals [cB] += valA_data[eA] * valB_data[eB];
10979 else
110- vals [cB] += 1 ;
80+ tmp_vals [cB]++ ;
11181 }
11282 }
11383
11484 for (auto k = 0 ; k < K; k++) {
115- if (vals [k] != 0 ) {
116- colC_data[nnz] = k ;
85+ if (tmp_vals [k] != 0 ) {
86+ cols. push_back (k) ;
11787 if (HAS_VALUE)
118- valC_data[nnz] = vals[k];
88+ vals. push_back (tmp_vals [k]) ;
11989 nnz++;
12090 }
121- vals [k] = (scalar_t )0 ;
91+ tmp_vals [k] = (scalar_t )0 ;
12292 }
93+ rowptrC_data[rA + 1 ] = nnz;
94+ }
12395
124- rowA_start = rowA_end;
96+ colC = torch::from_blob (cols.data (), {nnz}, colA.options ()).clone ();
97+ if (HAS_VALUE) {
98+ optional_valueC = torch::from_blob (vals.data (), {nnz},
99+ optional_valueA.value ().options ());
100+ optional_valueC = optional_valueC.value ().clone ();
125101 }
126102 });
127103 });
0 commit comments