77
88#include < sstream>
99
10- namespace TMVA {
11- namespace Experimental {
12- namespace SOFIE {
10+ namespace TMVA {
11+ namespace Experimental {
12+ namespace SOFIE {
1313
14- enum EBasicBinaryOperator { Add, Sub, Mul, Div, Pow };
14+ enum EBasicBinaryOperator {
15+ Add,
16+ Sub,
17+ Mul,
18+ Div,
19+ Pow
20+ };
1521
1622template <typename T, EBasicBinaryOperator Op1>
1723struct BinaryOperatorTrait {};
1824
1925template <typename T>
2026struct BinaryOperatorTrait <T, Add> {
2127 static const std::string Name () { return " Add" ; }
22- static std::string Op (const std::string & t1, const std::string t2) { return t1 + " + " + t2; }
23- static T Func (T t1, T t2) {return t1 + t2;}
28+ static std::string Op (const std::string &t1, const std::string t2) { return t1 + " + " + t2; }
29+ static T Func (T t1, T t2) { return t1 + t2; }
2430};
2531
2632template <typename T>
2733struct BinaryOperatorTrait <T, Sub> {
2834 static const std::string Name () { return " Sub" ; }
29- static std::string Op (const std::string & t1, const std::string t2) { return t1 + " - " + t2; }
30- static T Func (T t1, T t2) { return t1 - t2;}
35+ static std::string Op (const std::string &t1, const std::string t2) { return t1 + " - " + t2; }
36+ static T Func (T t1, T t2) { return t1 - t2; }
3137};
3238
3339template <typename T>
3440struct BinaryOperatorTrait <T, Mul> {
3541 static const std::string Name () { return " Mul" ; }
36- static std::string Op (const std::string & t1, const std::string t2) { return t1 + " * " + t2; }
37- static T Func (T t1, T t2) { return t1 * t2;}
42+ static std::string Op (const std::string &t1, const std::string t2) { return t1 + " * " + t2; }
43+ static T Func (T t1, T t2) { return t1 * t2; }
3844};
3945
4046template <typename T>
4147struct BinaryOperatorTrait <T, Div> {
4248 static const std::string Name () { return " Div" ; }
43- static std::string Op (const std::string & t1, const std::string t2) { return t1 + " / " + t2; }
44- static T Func (T t1, T t2) { return t1/ t2;}
49+ static std::string Op (const std::string &t1, const std::string t2) { return t1 + " / " + t2; }
50+ static T Func (T t1, T t2) { return t1 / t2; }
4551};
4652
4753template <typename T>
4854struct BinaryOperatorTrait <T, Pow> {
4955 static const std::string Name () { return " Pow" ; }
50- static std::string Op (const std::string & t1, const std::string t2) { return " std::pow(" + t1 + " ," + t2 + " )" ; }
51- static T Func (T t1, T t2) { return std::pow (t1,t2);}
56+ static std::string Op (const std::string &t1, const std::string t2) { return " std::pow(" + t1 + " ," + t2 + " )" ; }
57+ static T Func (T t1, T t2) { return std::pow (t1, t2); }
5258};
5359
54- template <typename T, EBasicBinaryOperator Op>
55- class ROperator_BasicBinary final : public ROperator{
60+ template <typename T, EBasicBinaryOperator Op>
61+ class ROperator_BasicBinary final : public ROperator {
5662private:
57-
5863 int fBroadcastFlag = 0 ;
5964 std::string fNA ;
6065 std::string fNB ;
@@ -71,28 +76,29 @@ private:
7176 std::vector<Dim> fDimShapeY ;
7277
7378public:
74- ROperator_BasicBinary (){}
75- ROperator_BasicBinary (std::string nameA, std::string nameB, std::string nameY):
76- fNA (UTILITY::Clean_name(nameA)), fNB (UTILITY::Clean_name(nameB)), fNY (UTILITY::Clean_name(nameY)){
77- fInputTensorNames = { fNA , fNB };
78- fOutputTensorNames = { fNY };
79- }
79+ ROperator_BasicBinary () {}
80+ ROperator_BasicBinary (std::string nameA, std::string nameB, std::string nameY)
81+ : fNA (UTILITY::Clean_name(nameA)), fNB (UTILITY::Clean_name(nameB)), fNY (UTILITY::Clean_name(nameY))
82+ {
83+ fInputTensorNames = {fNA , fNB };
84+ fOutputTensorNames = {fNY };
85+ }
8086
8187 // type of output given input
82- std::vector<ETensorType> TypeInference (std::vector<ETensorType> input) override {
83- return input;
84- }
88+ std::vector<ETensorType> TypeInference (std::vector<ETensorType> input) override { return input; }
8589
8690 // shape of output tensors given input tensors
87- std::vector<std::vector<size_t >> ShapeInference (std::vector<std::vector<size_t >> input) override {
91+ std::vector<std::vector<size_t >> ShapeInference (std::vector<std::vector<size_t >> input) override
92+ {
8893 // assume now inputs have same shape (no broadcasting)
8994 auto ret = std::vector<std::vector<size_t >>(1 , input[0 ]); // return vector size 1 with first input
9095 return ret;
9196 }
92-
93- void Initialize (RModel& model) override {
97+
98+ void Initialize (RModel &model) override
99+ {
94100 // input must be a graph input, or already initialized intermediate tensor
95- if (!model.CheckIfTensorAlreadyExist (fNA )){
101+ if (!model.CheckIfTensorAlreadyExist (fNA )) {
96102 throw std::runtime_error (std::string (" TMVA SOFIE Binary Op Input Tensor " ) + fNA + " is not found in model" );
97103 }
98104 if (!model.CheckIfTensorAlreadyExist (fNB )) {
@@ -113,10 +119,12 @@ public:
113119 fShapeB = model.GetTensorShape (fNB );
114120 fDimShapeB = ConvertShapeToDim (fShapeB );
115121 }
116- if (dynamicInputs & 1 && model.Verbose () )
117- std::cout << BinaryOperatorTrait<T, Op>::Name () << " : input " << fNA << " is dynamic " << ConvertShapeToString (fDimShapeA ) << " " ;
122+ if (dynamicInputs & 1 && model.Verbose ())
123+ std::cout << BinaryOperatorTrait<T, Op>::Name () << " : input " << fNA << " is dynamic "
124+ << ConvertShapeToString (fDimShapeA ) << " " ;
118125 if (dynamicInputs & 2 && model.Verbose ())
119- std::cout << BinaryOperatorTrait<T, Op>::Name () << " : input " << fNB << " is dynamic " << ConvertShapeToString (fDimShapeB ) << " " ;
126+ std::cout << BinaryOperatorTrait<T, Op>::Name () << " : input " << fNB << " is dynamic "
127+ << ConvertShapeToString (fDimShapeB ) << " " ;
120128 std::cout << std::endl;
121129 // check if need to broadcast at initialization time if shapes are known and different
122130 // (we could broadcast the tensor tensor to maximum values of dynamic shapes - to be done)
@@ -125,7 +133,7 @@ public:
125133 auto ret = UTILITY::MultidirectionalBroadcastShape (fShapeA , fShapeB );
126134 fBroadcastFlag = ret.first ;
127135 fShapeY = ret.second ;
128- bool broadcast = ret.first > 0 ;
136+ bool broadcast = ret.first > 0 ;
129137 if (broadcast) {
130138 // Y is the common shape of A and B
131139 bool broadcastA = ret.first & 2 ;
@@ -186,16 +194,16 @@ public:
186194 model.SetNotWritableInitializedTensor (nameB);
187195 fIsOutputConstant = true ;
188196 if (model.Verbose ()) {
189- std::cout << BinaryOperatorTrait<T, Op>::Name () << " : " << fNA << " " << ConvertShapeToString (fShapeA )
190- << " , " << fNB << " " << ConvertShapeToString (fShapeB ) << " ---> " << fNY
191- << " " << ConvertShapeToString (fShapeY ) << " : " << ConvertValuesToString (dataY) << std::endl;
197+ std::cout << BinaryOperatorTrait<T, Op>::Name () << " : " << fNA << " " << ConvertShapeToString (fShapeA )
198+ << " , " << fNB << " " << ConvertShapeToString (fShapeB ) << " ---> " << fNY << " "
199+ << ConvertShapeToString (fShapeY ) << " : " << ConvertValuesToString (dataY) << std::endl;
192200 }
193201 } else {
194202 model.AddIntermediateTensor (fNY , model.GetTensorType (fNA ), fShapeY );
195203 if (model.Verbose ()) {
196- std::cout << BinaryOperatorTrait<T, Op>::Name () << " : " << fNA << " " << ConvertShapeToString (fShapeA )
197- << " , " << fNB << " " << ConvertShapeToString (fShapeB ) << " ---> " << fNY
198- << " " << ConvertShapeToString (fShapeY ) << std::endl;
204+ std::cout << BinaryOperatorTrait<T, Op>::Name () << " : " << fNA << " " << ConvertShapeToString (fShapeA )
205+ << " , " << fNB << " " << ConvertShapeToString (fShapeB ) << " ---> " << fNY << " "
206+ << ConvertShapeToString (fShapeY ) << std::endl;
199207 }
200208 }
201209 // we convert non-dim shapes to Dim shapes
@@ -211,17 +219,18 @@ public:
211219 if (ret.first & 4 ) {
212220 // check if one of the parameter is an input dimension
213221 // define function to find this
214- auto IsInputDimParam = [&](const std::string & p) {
222+ auto IsInputDimParam = [&](const std::string &p) {
215223 auto inputNames = model.GetInputTensorNames ();
216- for (auto & input : inputNames) {
217- for (auto & i_s : model.GetDimTensorShape (input)) {
218- if (i_s.isParam && i_s.param == p) return true ;
224+ for (auto &input : inputNames) {
225+ for (auto &i_s : model.GetDimTensorShape (input)) {
226+ if (i_s.isParam && i_s.param == p)
227+ return true ;
219228 }
220229 }
221230 return false ;
222231 };
223232 for (size_t i = 0 ; i < fDimShapeY .size (); i++) {
224- auto & s = fDimShapeY [i];
233+ auto &s = fDimShapeY [i];
225234 if (s.isParam && s.param .find (" std::max" ) != std::string::npos) {
226235 if (IsInputDimParam (fDimShapeA [i].param )) {
227236 // case dim is 1 we indicate that the input parameter is equal to 1
@@ -238,7 +247,7 @@ public:
238247 }
239248 }
240249 }
241-
250+
242251 model.AddIntermediateTensor (fNY , model.GetTensorType (fNA ), fDimShapeY );
243252 if (model.Verbose ()) {
244253 std::cout << BinaryOperatorTrait<T, Op>::Name () << " : " << ConvertShapeToString (fDimShapeA ) << " , "
@@ -247,22 +256,25 @@ public:
247256 }
248257 }
249258
250- std::string GenerateInitCode () override {
259+ std::string GenerateInitCode () override
260+ {
251261 std::stringstream out;
252262 return out.str ();
253263 }
254264
255- std::string Generate (std::string opName) override {
265+ std::string Generate (std::string opName) override
266+ {
256267
257- if (fIsOutputConstant ) return " " ;
268+ if (fIsOutputConstant )
269+ return " " ;
258270
259271 opName = " op_" + opName;
260272
261273 if (fDimShapeY .empty ()) {
262274 throw std::runtime_error (" TMVA SOFIE Binary Op called to Generate without being initialized first" );
263275 }
264276 std::stringstream out;
265- out << SP << " \n //------ " << BinaryOperatorTrait<T,Op>::Name () << " \n " ;
277+ out << SP << " \n //------ " << BinaryOperatorTrait<T, Op>::Name () << " \n " ;
266278 auto length = ConvertDimShapeToLength (fDimShapeY );
267279 std::string typeName = TensorType<T>::Name ();
268280
@@ -273,82 +285,91 @@ public:
273285 auto lengthB = ConvertDimShapeToLength (fDimShapeB );
274286 out << SP << " if (" << lengthA << " !=" << lengthB << " ) {\n " ;
275287 // check if A->B or B->A
276- // bool broadcastable = true;
288+ // bool broadcastable = true;
277289 for (size_t i = 0 ; i < fDimShapeY .size (); i++) {
278- if (fBroadcastFlag & 5 && fDimShapeY [i] == fDimShapeA [i] && fDimShapeA [i].dim > 1 && fDimShapeB [i].isParam ) {
290+ if (fBroadcastFlag & 5 && fDimShapeY [i] == fDimShapeA [i] && fDimShapeA [i].dim > 1 &&
291+ fDimShapeB [i].isParam ) {
279292 // B->A B[i] needs to be 1
280293 out << SP << SP << " if (" << fDimShapeB [i] << " != 1)\n " ;
281294 out << SP << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast B->A in operator "
282- << opName << " \" );\n " ;
295+ << opName << " \" );\n " ;
283296 }
284- if (fBroadcastFlag & 6 && fDimShapeY [i] == fDimShapeB [i] && fDimShapeB [i].dim > 1 && fDimShapeA [i].isParam ) {
285- // A-> B A[i] needs to be 1
297+ if (fBroadcastFlag & 6 && fDimShapeY [i] == fDimShapeB [i] && fDimShapeB [i].dim > 1 &&
298+ fDimShapeA [i].isParam ) {
299+ // A-> B A[i] needs to be 1
286300 out << SP << SP << " if (" << fDimShapeA [i] << " != 1)\n " ;
287301 out << SP << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast A->B in operator "
288- << opName << " \" );\n " ;
289- }
290- else if (fDimShapeA [i].isParam && fDimShapeB [i].isParam ) {
302+ << opName << " \" );\n " ;
303+ } else if (fDimShapeA [i].isParam && fDimShapeB [i].isParam ) {
291304 // both shapes are parametric and we broadcast to maximum
292305 // we allocate here output vector
293- out << SP << SP << " if (" << fDimShapeA [i] << " != " << fDimShapeB [i] << " && ("
294- << fDimShapeA [i] << " != 1 || " << fDimShapeB [i] << " != 1))\n " ;
295- out << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast shapes in operator "
296- << opName << " \" );\n " ;
306+ out << SP << SP << " if (" << fDimShapeA [i] << " != " << fDimShapeB [i] << " && (" << fDimShapeA [i]
307+ << " != 1 || " << fDimShapeB [i] << " != 1))\n " ;
308+ out << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast shapes in operator " << opName
309+ << " \" );\n " ;
297310 }
298311 }
299312 }
300-
313+
301314 auto stridesA = UTILITY::ComputeStrideFromShape (fShapeA );
302315 auto stridesB = UTILITY::ComputeStrideFromShape (fShapeB );
303316 auto stridesY = UTILITY::ComputeStrideFromShape (fShapeY );
304317
305318 std::string compute_idx_A, compute_idx_B, compute_idx_Y;
306- if (std::all_of (fShapeA .begin (), fShapeA .end (), [](size_t x) { return x == 1 ; })){
319+ if (std::all_of (fShapeA .begin (), fShapeA .end (), [](size_t x) { return x == 1 ; })) {
307320 compute_idx_A = " 0" ;
308321 } else {
309- for (size_t i = 0 ; i<fShapeA .size (); ++i){
310- if (fShapeA [i]==1 ) continue ;
311- compute_idx_A += " idx_" +fNY +std::to_string (i+(fShapeY .size ()-fShapeA .size ()))+" * " +stridesA[i]+" +" ;
322+ for (size_t i = 0 ; i < fShapeA .size (); ++i) {
323+ if (fShapeA [i] == 1 )
324+ continue ;
325+ compute_idx_A +=
326+ " idx_" + fNY + std::to_string (i + (fShapeY .size () - fShapeA .size ())) + " * " + stridesA[i] + " +" ;
312327 }
313328 compute_idx_A.pop_back ();
314329 }
315- if (std::all_of (fShapeB .begin (), fShapeB .end (), [](size_t x) { return x == 1 ; })){
330+ if (std::all_of (fShapeB .begin (), fShapeB .end (), [](size_t x) { return x == 1 ; })) {
316331 compute_idx_B = " 0" ;
317332 } else {
318- for (size_t i = 0 ; i<fShapeB .size (); ++i){
319- if (fShapeB [i]==1 ) continue ;
320- compute_idx_B += " idx_" +fNY +std::to_string (i+(fShapeY .size ()-fShapeB .size ()))+" * " +stridesB[i]+" +" ;
333+ for (size_t i = 0 ; i < fShapeB .size (); ++i) {
334+ if (fShapeB [i] == 1 )
335+ continue ;
336+ compute_idx_B +=
337+ " idx_" + fNY + std::to_string (i + (fShapeY .size () - fShapeB .size ())) + " * " + stridesB[i] + " +" ;
321338 }
322339 compute_idx_B.pop_back ();
323340 }
324- for (size_t i=0 ; i<fShapeY .size (); ++i){
325- if (fShapeY [i]!=1 ){
326- out<<std::string (i + 1 , ' ' )<<" for(size_t idx_" <<fNY <<i<<" =0; idx_" <<fNY <<i<<" <" <<fShapeY [i]<<" ; ++idx_" <<fNY <<i<<" ){\n " ;
327- compute_idx_Y += " idx_" +fNY +std::to_string (i)+" *" +stridesY[i]+" +" ;
341+ for (size_t i = 0 ; i < fShapeY .size (); ++i) {
342+ if (fShapeY [i] != 1 ) {
343+ out << std::string (i + 1 , ' ' ) << " for(size_t idx_" << fNY << i << " =0; idx_" << fNY << i << " <"
344+ << fShapeY [i] << " ; ++idx_" << fNY << i << " ){\n " ;
345+ compute_idx_Y += " idx_" + fNY + std::to_string (i) + " *" + stridesY[i] + " +" ;
328346 }
329347 }
330348 compute_idx_Y.pop_back ();
331- out << SP << SP << " tensor_" << fNY <<" [" <<compute_idx_Y<<" ] = " <<BinaryOperatorTrait<T,Op>::Op (" tensor_" + fNA + " [" +compute_idx_A+" ]" , " tensor_" + fNB + " [" +compute_idx_B+" ]" )<<" ;\n " ;
332- for (size_t i=0 ; i<fShapeY .size (); ++i){
333- if (fShapeY [i]!=1 ){
334- out<<std::string (fShapeY .size ()-i+1 , ' ' )<<" }\n " ;
349+ out << SP << SP << " tensor_" << fNY << " [" << compute_idx_Y << " ] = "
350+ << BinaryOperatorTrait<T, Op>::Op (" tensor_" + fNA + " [" + compute_idx_A + " ]" ,
351+ " tensor_" + fNB + " [" + compute_idx_B + " ]" )
352+ << " ;\n " ;
353+ for (size_t i = 0 ; i < fShapeY .size (); ++i) {
354+ if (fShapeY [i] != 1 ) {
355+ out << std::string (fShapeY .size () - i + 1 , ' ' ) << " }\n " ;
335356 }
336357 }
337358 return out.str ();
338359 }
339360
340- std::vector<std::string> GetStdLibs () override {
361+ std::vector<std::string> GetStdLibs () override
362+ {
341363 if (Op == EBasicBinaryOperator::Pow) {
342- return { std::string (" cmath" ) };
364+ return {std::string (" cmath" )};
343365 } else {
344366 return {};
345367 }
346368 }
347369};
348370
349- }// SOFIE
350- }// Experimental
351- }// TMVA
352-
371+ } // namespace SOFIE
372+ } // namespace Experimental
373+ } // namespace TMVA
353374
354- #endif // TMVA_SOFIE_ROperator_BasicBinary
375+ #endif // TMVA_SOFIE_ROperator_BasicBinary
0 commit comments