@@ -7,9 +7,10 @@ namespace monolish {
77namespace {
88// double ///////////////////
99template <typename TENS2, typename TENS3>
10- void tensor_CRS_tensor_Dense_Dtensmul_core (const double &a, const tensor::tensor_CRS<double > &A,
11- const TENS2 &B, const double &b,
12- TENS3 &C){
10+ void tensor_CRS_tensor_Dense_Dtensmul_core (const double &a,
11+ const tensor::tensor_CRS<double > &A,
12+ const TENS2 &B, const double &b,
13+ TENS3 &C) {
1314 Logger &logger = Logger::get_instance ();
1415 logger.func_in (monolish_func);
1516
@@ -21,41 +22,55 @@ void tensor_CRS_tensor_Dense_Dtensmul_core(const double &a, const tensor::tensor
2122
2223 assert (col == Bshape[0 ]);
2324 std::vector<size_t > ABshape;
24- for (size_t i= 0 ; i+ 1 < Ashape.size (); ++i){
25+ for (size_t i = 0 ; i + 1 < Ashape.size (); ++i) {
2526 ABshape.push_back (Ashape[i]);
2627 }
27- for (size_t i= 1 ; i< Bshape.size (); ++i){
28+ for (size_t i = 1 ; i < Bshape.size (); ++i) {
2829 ABshape.push_back (Bshape[i]);
2930 }
3031 assert (ABshape == Cshape);
3132
3233 std::vector<size_t > ABshape_tmp = Bshape;
3334 ABshape_tmp[0 ] = row;
3435 size_t ABshape_dim = 1 ;
35- for (size_t i= 0 ; i< ABshape_tmp.size (); ++i){
36+ for (size_t i = 0 ; i < ABshape_tmp.size (); ++i) {
3637 ABshape_dim *= ABshape_tmp[i];
3738 }
3839
3940 size_t nsum = 0 ;
4041
41- for (size_t d=0 ; d<A.row_ptrs .size (); ++d){
42- matrix::CRS<double > Amat (row, col, A.row_ptrs [d], A.col_inds [d], A.get_val ());
43- Amat.set_first (A.get_offset () + nsum);
44- nsum += A.col_inds [d].size ();
45- tensor::tensor_Dense<double > Cmat (ABshape_tmp, C.get_val ());
46- Cmat.set_first (C.get_offset () + d * ABshape_dim);
42+ for (size_t d = 0 ; d < A.row_ptrs .size (); ++d) {
43+ std::vector<double > Aval (A.col_inds [d].size ());
44+ matrix::CRS<double > Amat (row, col, A.row_ptrs [d], A.col_inds [d], Aval);
45+ std::vector<double > Cval (ABshape_dim);
46+ tensor::tensor_Dense<double > Cmat (ABshape_tmp, Cval);
47+ if (A.get_device_mem_stat ()) {
48+ Amat.send ();
49+ Cmat.send ();
50+ }
51+ internal::vcopy (Aval.size (), A.begin () + nsum, Amat.begin (),
52+ A.get_device_mem_stat ());
53+ internal::vcopy (Cval.size (), C.begin () + d * ABshape_dim, Cmat.begin (),
54+ A.get_device_mem_stat ());
4755 CRS_tensor_Dense_Dmattens_core (a, Amat, B, b, Cmat);
56+ internal::vcopy (Cval.size (), Cmat.begin (), C.begin () + d * ABshape_dim,
57+ A.get_device_mem_stat ());
58+ if (A.get_device_mem_stat ()) {
59+ Amat.recv ();
60+ Cmat.recv ();
61+ }
62+ nsum += A.col_inds [d].size ();
4863 }
4964
5065 logger.func_out ();
51-
5266}
5367
5468// float ///////////////////
5569template <typename TENS2, typename TENS3>
56- void tensor_CRS_tensor_Dense_Stensmul_core (const float &a, const tensor::tensor_CRS<float > &A,
57- const TENS2 &B, const float &b,
58- TENS3 &C){
70+ void tensor_CRS_tensor_Dense_Stensmul_core (const float &a,
71+ const tensor::tensor_CRS<float > &A,
72+ const TENS2 &B, const float &b,
73+ TENS3 &C) {
5974 Logger &logger = Logger::get_instance ();
6075 logger.func_in (monolish_func);
6176
@@ -67,30 +82,44 @@ void tensor_CRS_tensor_Dense_Stensmul_core(const float &a, const tensor::tensor_
6782
6883 assert (col == Bshape[0 ]);
6984 std::vector<size_t > ABshape;
70- for (size_t i= 0 ; i+ 1 < Ashape.size (); ++i){
85+ for (size_t i = 0 ; i + 1 < Ashape.size (); ++i) {
7186 ABshape.push_back (Ashape[i]);
7287 }
73- for (size_t i= 1 ; i< Bshape.size (); ++i){
88+ for (size_t i = 1 ; i < Bshape.size (); ++i) {
7489 ABshape.push_back (Bshape[i]);
7590 }
7691 assert (ABshape == Cshape);
7792
7893 std::vector<size_t > ABshape_tmp = Bshape;
7994 ABshape_tmp[0 ] = row;
8095 size_t ABshape_dim = 1 ;
81- for (size_t i= 0 ; i< ABshape_tmp.size (); ++i){
96+ for (size_t i = 0 ; i < ABshape_tmp.size (); ++i) {
8297 ABshape_dim *= ABshape_tmp[i];
8398 }
8499
85100 size_t nsum = 0 ;
86101
87- for (size_t d=0 ; d<A.row_ptrs .size (); ++d){
88- matrix::CRS<float > Amat (row, col, A.row_ptrs [d], A.col_inds [d], A.get_val ());
89- Amat.set_first (A.get_offset () + nsum);
90- nsum += A.col_inds [d].size ();
91- tensor::tensor_Dense<float > Cmat (ABshape_tmp, C.get_val ());
92- Cmat.set_first (C.get_offset () + d * ABshape_dim);
102+ for (size_t d = 0 ; d < A.row_ptrs .size (); ++d) {
103+ std::vector<float > Aval (A.col_inds [d].size ());
104+ matrix::CRS<float > Amat (row, col, A.row_ptrs [d], A.col_inds [d], Aval);
105+ std::vector<float > Cval (ABshape_dim);
106+ tensor::tensor_Dense<float > Cmat (ABshape_tmp, Cval);
107+ if (A.get_device_mem_stat ()) {
108+ Amat.send ();
109+ Cmat.send ();
110+ }
111+ internal::vcopy (Aval.size (), A.begin () + nsum, Amat.begin (),
112+ A.get_device_mem_stat ());
113+ internal::vcopy (Cval.size (), C.begin () + d * ABshape_dim, Cmat.begin (),
114+ A.get_device_mem_stat ());
93115 CRS_tensor_Dense_Smattens_core (a, Amat, B, b, Cmat);
116+ internal::vcopy (Cval.size (), Cmat.begin (), C.begin () + d * ABshape_dim,
117+ A.get_device_mem_stat ());
118+ if (A.get_device_mem_stat ()) {
119+ Amat.recv ();
120+ Cmat.recv ();
121+ }
122+ nsum += A.col_inds [d].size ();
94123 }
95124
96125 logger.func_out ();
0 commit comments