@@ -337,6 +337,8 @@ namespace SOFIE{
337337 auto m = (fAttrTransA ? fShapeA [dimA-1 ].GetVal () : fShapeA [dimA-2 ].GetVal ());
338338 auto n = (fAttrTransB ? fShapeB [dimB-2 ].GetVal () : fShapeB [dimB-1 ].GetVal ());
339339 auto k = (fAttrTransA ? fShapeA [dimA-2 ].GetVal () : fShapeA [dimA-1 ].GetVal ());
340+ // size of A: if (transposeA) is k*m else m*k
341+ // size of B: if (transposeB) is n*k else k*n
340342 std::vector<Dim> sY = {fShapeY [dimY-2 ], fShapeY [dimY-1 ]};
341343 // extra dimensions in case of stacked MatMul
342344 std::vector<Dim> sA ;
@@ -371,24 +373,49 @@ namespace SOFIE{
371373 // include MatMul case where we stack the Gemm operations
372374 // exclude case where we have only 1's in the additional dims
373375 bool doStackMul = dimY > 2 && ( fIsDynamic || std::stoi (lengthExtra) > 1 );
376+ // compute input offset for stack multiplications
377+ std::string lengthExtra_A;
378+ std::string lengthExtra_B;
379+ std::string increment_A;
380+ std::string increment_B;
381+
382+ if (doStackMul) {
383+ std::vector<Dim> sA (fShapeA .begin (), fShapeA .begin ()+dimA-2 );
384+ std::vector<Dim> sB (fShapeB .begin (), fShapeB .begin ()+dimB-2 );
385+ std::vector<Dim> mA = {fShapeA [dimA-2 ], fShapeA [dimA-1 ]};
386+ std::vector<Dim> mB = {fShapeA [dimB-2 ], fShapeB [dimB-1 ]};
387+ lengthExtra_A = ConvertDimShapeToLength (sA );
388+ lengthExtra_B = ConvertDimShapeToLength (sB );
389+ // size of A performing matmul is m*k and n*k for B
390+ increment_A = ConvertDimShapeToLength (mA );
391+ increment_B = ConvertDimShapeToLength (mB );
392+ }
393+ bool extraA = (doStackMul && lengthExtra_A != " 1" );
394+ bool extraB = (doStackMul && lengthExtra_B != " 1" );
374395 if (doStackMul) {
375- out << SP << " size_t " << opName << " _yoffset = 0;\n " ; // needed if we stack the gemm operations
376- out << SP << " for (int i = 0; i < " << lengthExtra << " ; i++){\n " ;
396+ out << SP << " size_t " << opName << " _y_offset = 0;\n " ; // needed if we stack the gemm operations
397+ if (extraA)
398+ out << SP << " size_t " << opName << " _A_offset = 0;\n " ;
399+ if (extraB)
400+ out << SP << " size_t " << opName << " _B_offset = 0;\n " ;
401+ out << SP << " for (size_t i = 0; i < " << lengthExtra << " ; i++){\n " ;
377402 out << SP;
378403 }
379404
380405 if (fType == " float" ){
381406
382407 out << SP << " TMVA::Experimental::SOFIE::Gemm_Call("
383408 << " tensor_" << fNY ;
384- if (doStackMul) out << " + " << opName << " _yoffset " ;
409+ if (doStackMul) out << " + " << opName << " _y_offset " ;
385410 out << " , "
386411 << (fAttrTransB ? " true, " : " false, " )
387412 << (fAttrTransA ? " true, " : " false, " )
388413 << n << " , " << m << " , " << k << " , " ;
389- out << std::setprecision (std::numeric_limits<float >::max_digits10) << fAttrAlpha << " ," ;
390- out << " tensor_" << fNB << " , " << " tensor_" << fNA << " , " ;
391- out << std::setprecision (std::numeric_limits<float >::max_digits10) << fAttrBeta << " ," ;
414+ out << std::setprecision (std::numeric_limits<float >::max_digits10) << fAttrAlpha << " , tensor_" << fNB ;
415+ if (extraB) out << " + " << opName << " _B_offset" ;
416+ out << " , tensor_" << fNA ;
417+ if (extraA) out << " + " << opName << " _A_offset" ;
418+ out << " , " << std::setprecision (std::numeric_limits<float >::max_digits10) << fAttrBeta << " ," ;
392419 // in the case of bias
393420 if (!fNC .empty ())
394421 out << " tensor_" << fNC ;
@@ -404,7 +431,12 @@ namespace SOFIE{
404431 }
405432
406433 if (doStackMul) {
407- out << SP << SP << opName << " _yoffset += " << lengthGemm << " ;\n " ;
434+ out << SP << SP << opName << " _y_offset += " << lengthGemm << " ;\n " ;
435+ if (lengthExtra_A != " 1" )
436+ out << SP << SP << opName << " _A_offset += " << increment_A << " ;\n " ;
437+ if (lengthExtra_B != " 1" )
438+ out << SP << SP << opName << " _B_offset += " << increment_B << " ;\n " ;
439+
408440 out << " }\n " ; // end of loop on the stacked multiplications
409441 }
410442
0 commit comments