@@ -55,6 +55,7 @@ template<typename T, EBasicBinaryOperator Op>
5555class ROperator_BasicBinary final : public ROperator{
5656private:
5757
58+ int fBroadcastFlag = 0 ;
5859 std::string fNA ;
5960 std::string fNB ;
6061 std::string fNBroadcastedA ;
@@ -114,12 +115,14 @@ public:
114115 // case of known shapes
115116 if (!fShapeA .empty () && !fShapeB .empty ()) {
116117 auto ret = UTILITY::MultidirectionalBroadcastShape (fShapeA , fShapeB );
118+ fBroadcastFlag = ret.first ;
117119 fShapeY = ret.second ;
120+ std::cout << BinaryOperatorTrait<T, Op>::Name () << " checking for defined shapes " << fBroadcastFlag << " " << ConvertShapeToString (fShapeY ) << std::endl;
118121 bool broadcast = ret.first > 0 ;
119122 if (broadcast) {
120123 // Y is the common shape of A and B
121- bool broadcastA = ret.first > 1 ;
122- bool broadcastB = ret.first == 1 || ret. first == 3 ;
124+ bool broadcastA = ret.first & 2 ;
125+ bool broadcastB = ret.first & 1 ;
123126 // Broadcast A to Y
124127 if (broadcastA) {
125128 fNBroadcastedA = " Broadcasted" + fNA + " to" + fNY ;
@@ -191,17 +194,28 @@ public:
191194 else {
192195 // case A or B have dynamic shapes. We need to broadcast if the shapes are not the same
193196 auto ret = UTILITY::MultidirectionalBroadcastShape (fDimShapeA , fDimShapeB );
197+ fBroadcastFlag = ret.first ;
194198 fDimShapeY = ret.second ;
195- if (ret.first > 1 ) {
199+ std::cout << BinaryOperatorTrait<T, Op>::Name () << " : checking for Dim shapes " << fBroadcastFlag << " " << ConvertShapeToString (fDimShapeY ) << std::endl;
200+ if (ret.first & 2 ) {
196201 // case we broadcast A
197202 fNBroadcastedA = " Broadcasted" + fNA + " to" + fNY ;
198203 model.AddIntermediateTensor (fNBroadcastedA , model.GetTensorType (fNA ), fDimShapeY );
199204 }
200- if (ret.first == 1 || ret. first == 3 ) {
205+ if (ret.first & 1 ) {
201206 // case we broadcast B
202207 fNBroadcastedB = " Broadcasted" + fNB + " to" + fNY ;
203208 model.AddIntermediateTensor (fNBroadcastedB , model.GetTensorType (fNB ), fDimShapeY );
204209 }
210+ // case where all shapes are parametric and are known only at run time:
211+ // in this case we don't add an intermediate tensor for broadcasting
212+ // if (ret.first == 4) {
213+ // for (auto & d : fDimShapeY) {
214+ // if (d.isParam && d.param.find("broadcast") != std::string::npos) {
215+ // d.param += fNY;
216+ // }
217+ // }
218+ // }
205219 // add output tensor
206220 model.AddIntermediateTensor (fNY , model.GetTensorType (fNA ), fDimShapeY );
207221 }
@@ -212,11 +226,11 @@ public:
212226 return out.str ();
213227 }
214228
215- std::string Generate (std::string OpName ) override {
229+ std::string Generate (std::string opName ) override {
216230
217231 if (fIsOutputConstant ) return " " ;
218232
219- OpName = " op_" + OpName ;
233+ opName = " op_" + opName ;
220234
221235 if (fDimShapeY .empty ()) {
222236 throw std::runtime_error (" TMVA SOFIE Binary Op called to Generate without being initialized first" );
@@ -225,21 +239,55 @@ public:
225239 out << SP << " \n //------ " << BinaryOperatorTrait<T,Op>::Name () << " \n " ;
226240 auto length = ConvertDimShapeToLength (fDimShapeY );
227241 std::string typeName = TensorType<T>::Name ();
242+ // we need to check if we can broadcast (case where the flag has the bit with value 4 set)
243+ if (fBroadcastFlag & 4 ) {
244+ // need to check if shapes are the same
245+ auto lengthA = ConvertDimShapeToLength (fDimShapeA );
246+ auto lengthB = ConvertDimShapeToLength (fDimShapeB );
247+ out << SP << " if (" << lengthA << " !=" << lengthB << " ) {\n " ;
248+ // check if A->B or B->A
249+ // bool broadcastable = true;
250+ for (size_t i = 0 ; i < fDimShapeY .size (); i++) {
251+ if (fBroadcastFlag & 5 && fDimShapeY [i] == fDimShapeA [i] && fDimShapeA [i].dim > 1 && fDimShapeB [i].isParam ) {
252+ // B->A B[i] needs to be 1
253+ out << SP << SP << " if (" << fDimShapeB [i] << " != 1)\n " ;
254+ out << SP << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast B->A in operator "
255+ << opName << " \" );\n " ;
256+ }
257+ if (fBroadcastFlag & 6 && fDimShapeY [i] == fDimShapeB [i] && fDimShapeB [i].dim > 1 && fDimShapeA [i].isParam ) {
258+ // A-> B A[i] needs to be 1
259+ out << SP << SP << " if (" << fDimShapeA [i] << " != 1)\n " ;
260+ out << SP << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast A->B in operator "
261+ << opName << " \" );\n " ;
262+ }
263+ else if (fDimShapeA [i].isParam && fDimShapeB [i].isParam ) {
264+ // both shapes are parametric and we broadcast to maximum
265+ // we allocate here output vector
266+ out << SP << SP << " if (" << fDimShapeA [i] << " != " << fDimShapeB [i] << " && ("
267+ << fDimShapeA [i] << " != 1 || " << fDimShapeB [i] << " != 1))\n " ;
268+ out << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast shapes in operator "
269+ << opName << " \" );\n " ;
270+ }
271+ }
272+ } else {
273+ out << SP << " {\n " ;
274+ }
228275 // Broadcast A if it's uninitialized
229276 // use broadcasting function where we pass an already allocated tensor to minimize memory allocations
230277 if (!fNBroadcastedA .empty ()) {
231- out << SP << " // Broadcasting uninitialized tensor " << fNA << " \n " ;
232- out << SP << " TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << " >(tensor_" << fNA << " , "
278+ out << SP << SP << " // Broadcasting uninitialized tensor " << fNA << " \n " ;
279+ out << SP << SP << " TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << " >(tensor_" << fNA << " , "
233280 << ConvertDimShapeToString (fDimShapeA ) << " , " << ConvertDimShapeToString (fDimShapeY )
234281 << " , fTensor_" << fNBroadcastedA << " );\n " ;
235282 }
236283 // Broadcast B if it's uninitialized
237284 if (!fNBroadcastedB .empty ()) {
238- out << SP << " // Broadcasting uninitialized tensor " << fNB << " \n " ;
239- out << SP << " TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << " >(tensor_" << fNB << " , "
285+ out << SP << SP << " // Broadcasting uninitialized tensor " << fNB << " \n " ;
286+ out << SP << SP << " TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << " >(tensor_" << fNB << " , "
240287 << ConvertDimShapeToString (fDimShapeB ) << " , " << ConvertDimShapeToString (fDimShapeY )
241288 << " , fTensor_" << fNBroadcastedB << " );\n " ;
242289 }
290+ out << SP << " }\n " ; // end if on broadcasting
243291 const std::string& nameA = fNBroadcastedA .empty ()? fNA : fNBroadcastedA ;
244292 const std::string& nameB = fNBroadcastedB .empty ()? fNB : fNBroadcastedB ;
245293 out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
0 commit comments