Skip to content

Commit 72f683c

Browse files
committed
[tmva][sofie] Fix stacked MatMul and speedup LayerNorm
Apply also other fixes for the SOFIE tests and add a new test for StackMul
1 parent 21f3675 commit 72f683c

File tree

7 files changed

+251
-172
lines changed

7 files changed

+251
-172
lines changed

tmva/sofie/inc/TMVA/ROperator_Gather.hxx

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@ public:
7272
// empty shape Indices is a scalar value for the indices
7373
size_t indicesLength = ConvertShapeToLength(model.GetTensorShape(fNIndices));
7474
int64_t* indicesData = static_cast<int64_t*>(model.GetInitializedTensorData(fNIndices).get());
75-
//flag index tensor as not writable (not sure this is needed since index tensor might be used in generated code)
76-
model.SetNotWritableInitializedTensor(fNIndices);
7775
// update indices data in case of negative dim values
7876
for (size_t i = 0; i < indicesLength; i++) {
7977
// move this at generation time?

tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ namespace SOFIE{
337337
auto m = (fAttrTransA ? fShapeA[dimA-1].GetVal() : fShapeA[dimA-2].GetVal());
338338
auto n = (fAttrTransB ? fShapeB[dimB-2].GetVal() : fShapeB[dimB-1].GetVal());
339339
auto k = (fAttrTransA ? fShapeA[dimA-2].GetVal() : fShapeA[dimA-1].GetVal());
340+
// size of A: if (transposeA) is m*k else k*m
341+
// size of B is n*k
340342
std::vector<Dim> sY = {fShapeY[dimY-2], fShapeY[dimY-1]};
341343
// extra dimensions in case of stacked MatMul
342344
std::vector<Dim> sA;
@@ -371,24 +373,49 @@ namespace SOFIE{
371373
// include MatMul case where we stack the Gemm operations
372374
// exclude case where we have only 1's in the additional dims
373375
bool doStackMul = dimY > 2 && ( fIsDynamic || std::stoi(lengthExtra) > 1);
376+
// compute input offset for stack multiplications
377+
std::string lengthExtra_A;
378+
std::string lengthExtra_B;
379+
std::string increment_A;
380+
std::string increment_B;
381+
382+
if (doStackMul) {
383+
std::vector<Dim> sA(fShapeA.begin(), fShapeA.begin()+dimA-2);
384+
std::vector<Dim> sB(fShapeB.begin(), fShapeB.begin()+dimB-2);
385+
std::vector<Dim> mA = {fShapeA[dimA-2], fShapeA[dimA-1]};
386+
std::vector<Dim> mB = {fShapeB[dimB-2], fShapeB[dimB-1]};
387+
lengthExtra_A = ConvertDimShapeToLength(sA);
388+
lengthExtra_B = ConvertDimShapeToLength(sB);
389+
// size of A performing matmul is m*k and n*k for B
390+
increment_A = ConvertDimShapeToLength(mA);
391+
increment_B = ConvertDimShapeToLength(mB);
392+
}
393+
bool extraA = (doStackMul && lengthExtra_A != "1");
394+
bool extraB = (doStackMul && lengthExtra_B != "1");
374395
if (doStackMul) {
375-
out << SP << "size_t " << opName << "_yoffset = 0;\n"; // needed if we stack the gemm operations
376-
out << SP << "for (int i = 0; i < " << lengthExtra << "; i++){\n";
396+
out << SP << "size_t " << opName << "_y_offset = 0;\n"; // needed if we stack the gemm operations
397+
if (extraA)
398+
out << SP << "size_t " << opName << "_A_offset = 0;\n";
399+
if (extraB)
400+
out << SP << "size_t " << opName << "_B_offset = 0;\n";
401+
out << SP << "for (size_t i = 0; i < " << lengthExtra << "; i++){\n";
377402
out << SP;
378403
}
379404

380405
if (fType == "float"){
381406

382407
out << SP << "TMVA::Experimental::SOFIE::Gemm_Call("
383408
<< "tensor_" << fNY;
384-
if (doStackMul) out << " + " << opName << "_yoffset";
409+
if (doStackMul) out << " + " << opName << "_y_offset";
385410
out << ", "
386411
<< (fAttrTransB ? "true, " : "false, ")
387412
<< (fAttrTransA ? "true, " : "false, ")
388413
<< n << ", " << m << ", " << k << ", ";
389-
out << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrAlpha << ",";
390-
out << "tensor_" << fNB << ", " << "tensor_" << fNA << ", ";
391-
out << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrBeta << ",";
414+
out << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrAlpha << ", tensor_" << fNB;
415+
if (extraB) out << " + " << opName << "_B_offset";
416+
out << ", tensor_" << fNA;
417+
if (extraA) out << " + " << opName << "_A_offset";
418+
out << ", " << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrBeta << ",";
392419
// in the case of bias
393420
if (!fNC.empty())
394421
out << "tensor_" << fNC;
@@ -404,7 +431,12 @@ namespace SOFIE{
404431
}
405432

406433
if (doStackMul) {
407-
out << SP << SP << opName << "_yoffset += " << lengthGemm << ";\n";
434+
out << SP << SP << opName << "_y_offset += " << lengthGemm << ";\n";
435+
if (lengthExtra_A != "1")
436+
out << SP << SP << opName << "_A_offset += " << increment_A << ";\n";
437+
if (lengthExtra_B != "1")
438+
out << SP << SP << opName << "_B_offset += " << increment_B << ";\n";
439+
408440
out << "}\n"; // end of loop on the stacked multiplications
409441
}
410442

0 commit comments

Comments
 (0)