
Commit e308acb

[tmva][sofie] Add support for dynamic tensors in Expand, Concat, Reduce and SoftMax
Also fix dynamic support for other operators. Can now parse the ATLAS GN2 network.
1 parent 1c6fd66 commit e308acb

12 files changed: +321 −150 lines

tmva/sofie/inc/TMVA/RModel.hxx

Lines changed: 3 additions & 16 deletions
@@ -58,20 +58,6 @@ public:
 
    const ETensorType &GetTensorType(std::string name) const;
 
-   // template<class T>
-   // void GetTensorShape(const std::string & name, std::vector<T> & shape) {
-   //    if (TensorShape<T>::IsDim()) {
-   //       if (model.IsDynamicTensor(name)) {
-   //          shape = model.GetDynamicTensorShape(name);
-   //       } else {
-   //          intShape = model.GetTensorShape(name);
-   //          shape = ConvertShapeToDim(intShape);
-   //       }
-   //    } else {
-   //       shape = model.GetTensorShape(name);
-   //    }
-
-   // }
 
    bool CheckIfTensorAlreadyExist(std::string tensor_name);
    void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<Dim> shape);
@@ -180,8 +166,9 @@ protected:
    void GenerateSessionCode();
 
 public:
-   const std::vector<std::string> &GetInputTensorNames() const { return fInputTensorNames; }
-   const std::vector<std::string> &GetOutputTensorNames() const { return fOutputTensorNames; }
+   const std::vector<std::string> & GetInputTensorNames() const { return fInputTensorNames; }
+   const std::vector<std::string> & GetOutputTensorNames() const { return fOutputTensorNames; }
+   const std::vector<std::string> & GetDimShapeNames() const { return fDimShapeNames; }
 
    void ReadInitializedTensorsFromFile(long);
    long WriteInitializedTensorsToFile(std::string filename = "");
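The added GetDimShapeNames() accessor exposes, in registration order, the names of the shape parameters the model uses for dynamic tensors; the Concat changes below rely on that ordering to pick a canonical parameter when two parametric dimensions must agree. A minimal self-contained sketch of that lookup, with a simplified Dim struct and a free function as illustrative stand-ins (not SOFIE's actual types):

// Sketch only: Dim and PickCanonical are simplified stand-ins, not SOFIE's API.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct Dim {
   bool isParam = false;  // true if the size is a named run-time parameter
   size_t dim = 0;        // concrete size when known
   std::string param;     // parameter name when isParam is true
};

// Prefer the parameter that appears first in the model's name list,
// mirroring the "check which parameter is first in RModel list" logic below.
Dim PickCanonical(const std::vector<std::string> &dimNames, const Dim &a, const Dim &b)
{
   auto p1 = std::find(dimNames.begin(), dimNames.end(), a.param);
   auto p2 = std::find(dimNames.begin(), dimNames.end(), b.param);
   return (p1 < p2) ? a : b;
}

int main()
{
   std::vector<std::string> dimNames = {"batch", "seq"}; // assumed registration order
   Dim a{true, 0, "seq"}, b{true, 0, "batch"};
   std::cout << PickCanonical(dimNames, a, b).param << "\n"; // prints "batch"
}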

tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 48 additions & 14 deletions
@@ -117,7 +117,7 @@ public:
       auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB);
       fBroadcastFlag = ret.first;
       fShapeY = ret.second;
-      std::cout << BinaryOperatorTrait<T, Op>::Name() << "checking for defined shapes " << fBroadcastFlag << " " << ConvertShapeToString(fShapeY) << std::endl;
+      std::cout << BinaryOperatorTrait<T, Op>::Name() << " : checking for defined shapes " << fBroadcastFlag << " " << ConvertShapeToString(fShapeY) << std::endl;
       bool broadcast = ret.first > 0;
       if (broadcast) {
          // Y is the common shape of A and B
@@ -182,11 +182,18 @@ public:
          model.SetNotWritableInitializedTensor(nameA);
          model.SetNotWritableInitializedTensor(nameB);
          fIsOutputConstant = true;
-         if (model.Verbose())
-            std::cout << "Binary op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : "
-                      << ConvertValuesToString(dataY) << std::endl;
+         if (model.Verbose()) {
+            std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
+                      << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY
+                      << " " << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl;
+         }
       } else {
          model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
+         if (model.Verbose()) {
+            std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
+                      << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY
+                      << " " << ConvertShapeToString(fShapeY) << std::endl;
+         }
       }
       // we convert non-dim shapes to Dim shapes
       fDimShapeY = ConvertShapeToDim(fShapeY);
@@ -197,6 +204,38 @@ public:
       fBroadcastFlag = ret.first;
       fDimShapeY = ret.second;
       std::cout << BinaryOperatorTrait<T, Op>::Name() << " : checking for Dim shapes " << fBroadcastFlag << " " << ConvertShapeToString(fDimShapeY) << std::endl;
+      // case of all parametric shapes and MultiDirectionalBroadcastShape return the max of the 2
+      // need to do before we declare the output tensor shape and the broadcasted ones
+      if (ret.first & 4) {
+         // check if one of the parameter is an input dimension
+         // define function to find this
+         auto IsInputDimParam = [&](const std::string & p) {
+            auto inputNames = model.GetInputTensorNames();
+            for (auto & input : inputNames) {
+               for (auto & i_s : model.GetDimTensorShape(input)) {
+                  if (i_s.isParam && i_s.param == p) return true;
+               }
+            }
+            return false;
+         };
+         for (size_t i = 0; i < fDimShapeY.size(); i++) {
+            auto & s = fDimShapeY[i];
+            if (s.isParam && s.param.find("std::max") != std::string::npos) {
+               if (IsInputDimParam(fDimShapeA[i].param)) {
+                  // case dim is 1 we indicate that the input parameter is equal to 1
+                  if (fDimShapeA[i].dim != 1)
+                     s = fDimShapeA[i];
+                  else
+                     s = fDimShapeB[i];
+               } else if (IsInputDimParam(fDimShapeB[i].param)) {
+                  if (fDimShapeB[i].dim != 1)
+                     s = fDimShapeB[i];
+                  else
+                     s = fDimShapeA[i];
+               }
+            }
+         }
+      }
       if (ret.first & 2) {
          // case we broadcast A
          fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
@@ -207,17 +246,12 @@ public:
          fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
          model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fDimShapeY);
       }
-      // case of all parametric shapes and we know only at run time
-      // we don't add in this case an intermediate tensor for broadcasting
-      // if (ret.first == 4) {
-      //    for (auto & d : fDimShapeY) {
-      //       if (d.isParam && d.param.find("broadcast") != std::string::npos) {
-      //          d.param += fNY;
-      //       }
-      //    }
-      // }
-      // add output tensor
+
       model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fDimShapeY);
+      if (model.Verbose()) {
+         std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << ConvertShapeToString(fDimShapeA) << " , "
+                   << ConvertShapeToString(fDimShapeB) << " --> " << ConvertShapeToString(fDimShapeY) << std::endl;
+      }
    }
 }
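The new `ret.first & 4` branch above handles fully parametric shapes: when multidirectional broadcasting can only express an output size as a run-time `std::max(a,b)`, the operator now substitutes the operand whose parameter is a genuine input dimension, treating a recorded size of 1 as "this parameter equals 1" so the other operand wins. A self-contained sketch of that resolution rule under simplified, assumed types (not SOFIE's actual API):

// Sketch only: Dim and ResolveMax are simplified stand-ins, not SOFIE's API.
#include <iostream>
#include <string>
#include <vector>

struct Dim {
   bool isParam = false;
   size_t dim = 0;      // size_t(-1) marks "unknown until run time"
   std::string param;
};

// Resolve one output dimension from the two operand dimensions, given the
// parameter names known to come from model inputs.
Dim ResolveMax(const Dim &a, const Dim &b, const std::vector<std::string> &inputParams)
{
   auto isInputParam = [&](const std::string &p) {
      for (const auto &name : inputParams)
         if (name == p) return true;
      return false;
   };
   // dim == 1 means the parameter is known to equal 1, so the other side wins
   if (isInputParam(a.param)) return (a.dim != 1) ? a : b;
   if (isInputParam(b.param)) return (b.dim != 1) ? b : a;
   // neither parameter is an input dimension: keep the run-time expression
   return Dim{true, static_cast<size_t>(-1), "std::max(" + a.param + "," + b.param + ")"};
}

int main()
{
   Dim a{true, static_cast<size_t>(-1), "batch"}; // an input dimension
   Dim b{true, 1, "k"};                           // parameter known to equal 1
   std::cout << ResolveMax(a, b, {"batch"}).param << "\n"; // prints "batch"
}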

tmva/sofie/inc/TMVA/ROperator_Concat.hxx

Lines changed: 48 additions & 30 deletions
@@ -56,6 +56,7 @@
          throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value ");
 
       int concat_dim=0;
+      // case of Concat (fNewAxis = 0) and not ConcatFromSequence
       if(fnewAxis == 0){
          for (size_t i = 0; i < inputs.size(); i++) {
             if (i > 0 && inputs[i].size() != inputs[i - 1].size())
@@ -76,6 +77,7 @@
          ret[0][fAxis] = concat_dim;
       }
       std::vector<int> stack;
+      // case ConcatFromSequence
       if(fnewAxis == 1){
          for(size_t i = 0; i < inputs.size(); i++) {
             if (i > 0 && inputs[i].size() != inputs[i-1].size() )
@@ -99,58 +101,74 @@
    }
 
    // get shape of output given inputs. It is going to be called after initialized
-   std::vector<std::vector<Dim>> ShapeInference(const std::vector<std::vector<Dim>> & inputs) {
-      std::vector<std::vector<Dim>> ret(1);
+   std::vector<Dim> ShapeInference(const std::vector<std::vector<Dim>> & inputs, const RModel & model) {
+      std::vector<Dim> ret(inputs[0].size());
       // treat negative axis case
       if (fAxis<0) {
          fAxis = inputs[0].size()+fAxis;
       }
       if (fAxis < 0 || fAxis >= (int) inputs[0].size())
         throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value ");
 
-      std::string concat_dim;
-      size_t i_concat_dim = 0;
+      Dim concat_dim;
       if(fnewAxis == 0){
         for (size_t i = 0; i < inputs.size(); i++) {
            if (i > 0 && inputs[i].size() != inputs[i - 1].size())
               throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " + fInputs[i] + " : " +
                  ConvertShapeToString(inputs[i]) + " and " + fInputs[i-1] + " : " + ConvertShapeToString(inputs[i - 1]));
            for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
               if ((int)iaxis == fAxis) {
-                 // support only non-params shape for the concatenation axis
-                 if (inputs[i][iaxis].isParam) {
-                    if (concat_dim.empty())
-                       concat_dim = inputs[i][iaxis].GetVal();
-                    else
-                       concat_dim += std::string("+ ") + inputs[i][iaxis].GetVal();
+                 // support both integer and params shape for the concatenation axis
+                 if (concat_dim.param.empty() && concat_dim.dim == 0)
+                    concat_dim = inputs[i][iaxis];
+                 else if (inputs[i][iaxis].isParam || concat_dim.isParam) {
+                    concat_dim =
+                       Dim{ concat_dim.GetVal() + std::string("+ ") + inputs[i][iaxis].GetVal(),
+                            static_cast<size_t>(-1)};
                  } else {
-                    i_concat_dim += inputs[i][iaxis].dim;
-                    concat_dim = std::to_string(i_concat_dim);
+                    concat_dim = Dim { concat_dim.dim + inputs[i][iaxis].dim };
                  }
               }
-              // other dimensions must be the same
-              else if (i > 0 && inputs[i][iaxis].GetVal() != inputs[i - 1][iaxis].GetVal())
+              else if (i == 0) {
+                 ret[iaxis] = inputs[i][iaxis];
+              }
+              else if ((!inputs[i][iaxis].isParam && !ret[iaxis].isParam) && (inputs[i][iaxis].dim != ret[iaxis].dim)) {
                  throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
                     ConvertShapeToString(inputs[i]) + " and " +
                     ConvertShapeToString(inputs[i - 1]));
+              }
+              else if (!inputs[i][iaxis].isParam && ret[iaxis].isParam){
+                 // if shape is not parametric use it
+                 ret[iaxis] = inputs[i][iaxis];
+              }
+              else if (inputs[i][iaxis].isParam && ret[iaxis].isParam) {
+                 // check which parameter is first in RModel list
+                 auto & dimNames = model.GetDimShapeNames();
+                 auto p1 = std::find(dimNames.begin(), dimNames.end(), inputs[i][iaxis].param);
+                 auto p2 = std::find(dimNames.begin(), dimNames.end(), ret[iaxis].param);
+                 if (p1 < p2) ret[iaxis] = inputs[i][iaxis];
+              }
+
           }
        }
 
-        // output shape
-        ret[0] = inputs[0];
-        // check if concat_dim is an integer
-        // case like "2+n" can be converted to an integer so need to check the length
-        size_t pos = 0;
-        try {
-           i_concat_dim = std::stoi(concat_dim, &pos);
-           if (pos == concat_dim.length())
-              ret[0][fAxis] = Dim{i_concat_dim}; // dimension is integer
-           else
-              ret[0][fAxis] = Dim{concat_dim};
-        }
-        catch (std::invalid_argument const& ex) {
-           ret[0][fAxis] = Dim{concat_dim};
-        }
+        // output shape for concatenated axis
+        ret[fAxis] = Dim{concat_dim};
+        // //ret[0] = inputs[0];
+        // // check if concat_dim is an integer
+        // // case like "2+n" can be converted to an integer so need to check the length
+        // size_t pos = 0;
+        // try {
+        //    i_concat_dim = std::stoi(concat_dim, &pos);
+        //    if (pos == concat_dim.length())
+        //       ret[fAxis] = Dim{i_concat_dim}; // dimension is integer
+        //    else {
+        //       // check if a composite expression
+        //       ret[fAxis] = Dim{concat_dim};
+        //    }
+        //    catch (std::invalid_argument const& ex) {
+
+        // }
 
      }
      // case of stacking (not supported yet)
@@ -170,7 +188,7 @@
          }
          fInputShapes.push_back(model.GetDimTensorShape(it));
       }
-      fOutputShape = ShapeInference(fInputShapes)[0];
+      fOutputShape = ShapeInference(fInputShapes, model);
       if (model.Verbose())
          std::cout << "Output of concat operator has shape " << ConvertDimShapeToString(fOutputShape) << std::endl;
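The reworked ShapeInference accumulates the concatenation axis as a Dim instead of a string: integer sizes add numerically, and as soon as one term is parametric the running sum becomes a string expression evaluated at run time, which removes the old std::stoi round-trip. A self-contained sketch of that accumulation, again with a simplified Dim as an assumed stand-in for SOFIE's real struct:

// Sketch only: Dim and AccumulateConcatDim are simplified stand-ins.
#include <iostream>
#include <string>
#include <vector>

struct Dim {
   bool isParam = false;
   size_t dim = 0;
   std::string param;
   std::string GetVal() const { return isParam ? param : std::to_string(dim); }
};

// Sum the sizes of all inputs along the concatenation axis.
Dim AccumulateConcatDim(const std::vector<Dim> &axisDims)
{
   Dim total; // starts "empty": param == "" and dim == 0
   for (const auto &d : axisDims) {
      if (total.param.empty() && total.dim == 0)
         total = d; // first term
      else if (d.isParam || total.isParam)
         // at least one parametric term: fall back to a run-time expression
         total = Dim{true, static_cast<size_t>(-1), total.GetVal() + "+" + d.GetVal()};
      else
         total = Dim{false, total.dim + d.dim, ""}; // both integers: add
   }
   return total;
}

int main()
{
   // e.g. concatenating axis sizes 2, n and 3
   std::vector<Dim> dims = {{false, 2, ""}, {true, 0, "n"}, {false, 3, ""}};
   std::cout << AccumulateConcatDim(dims).GetVal() << "\n"; // prints "2+n+3"
}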
