Skip to content

Commit 084bd88

Browse files
authored
[tmva][sofie] fix compute strides for output tensor in topK and add input tensors for memory optimization in Gemm operator (#19023)
1 parent 0f5e8df commit 084bd88

File tree

2 files changed

+7
-24
lines changed

2 files changed

+7
-24
lines changed

tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ namespace SOFIE{
6464
fActivation = activation;
6565
fType = "float";
6666

67+
fInputTensorNames = {fNA, fNB, fNC};
6768
fOutputTensorNames = { fNY };
6869
}
6970

tmva/sofie/inc/TMVA/ROperator_TopK.hxx

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,10 @@ public:
8282
}
8383
// fK cannot be larger that axis dimension
8484
fK = std::min(fK, fShapeX[fAttrAxis]);
85-
// if(fK>fShapeX[fAttrAxis]){
86-
// throw
87-
// std::runtime_error("TMVA::SOFIE ONNX TopK op k = "+ std::to_string(fK) +" value exeeds value of tensor " +fNX+" of size "+fShapeX.size()+" at axis= "+std::to_string(fAttrAxis)+".");
88-
// }
89-
// fShapeX = model.GetTensorShape(fNX); // [ m x n x o x p ... ]
90-
// if(k[0]>=fShapeX.size()){
91-
// throw
92-
// std::runtime_error("TMVA::SOFIE ONNX TopK op k = "+ std::to_string(k[0]) +"value exeeds size of tensor " +fNX+" of size "+fShapeX.size()+" .");
93-
// }
94-
// fShapeY.push_back(2);
95-
// for (auto i : fShapeX)
96-
// fShapeY.push_back(i); // [ 2 x m x n x o x p ... ]
97-
// size_t axis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
98-
// fShapeY[axis] = k[0]; // [ 2 x m x n x K x p ... ]
99-
fShapeY=ShapeInference({fShapeX,fShapeK})[0];
100-
101-
// for(int i=0;i<fShapeX.size();i++)
102-
// std::cout<<fShapeX[i]<<" ";
103-
// std::cout<<"\ny size -> "<<fShapeY.size()<<std::endl;
104-
10585

86+
fShapeY = ShapeInference({fShapeX, fShapeK})[0];
10687
model.AddIntermediateTensor(fNVal, model.GetTensorType(fNX), fShapeY);
88+
10789
// output indices should be an int64 tensor
10890
model.AddIntermediateTensor(fNInd, ETensorType::INT64, fShapeY);
10991
fType = ConvertTypeToString(model.GetTensorType(fNX));
@@ -121,7 +103,7 @@ public:
121103

122104
size_t length=ConvertShapeToLength(fShapeX);
123105
auto strideX = UTILITY::ComputeStrideFromShape(fShapeX);
124-
auto strideY = UTILITY::ComputeStrideFromShape(fShapeX);
106+
auto strideY = UTILITY::ComputeStrideFromShape(fShapeY);
125107
// we perform loop on dimension before sorted axis and after sorted axis
126108
size_t n_before = (axis>0) ? length/strideX[axis-1] : 1;
127109
size_t n_after = strideX[axis];
@@ -174,8 +156,8 @@ public:
174156
}
175157
};
176158

177-
} // nameSPace SOFIE
178-
} // nameSPace Experimental
179-
} // nameSPace TMVA
159+
} // namespace SOFIE
160+
} // namespace Experimental
161+
} // namespace TMVA
180162

181163
#endif // TMVA_SOFIE_ROPERATOR_TOPK

0 commit comments

Comments
 (0)