@@ -129,33 +129,37 @@ public:
       // check if need to broadcast at initialization time if shapes are known and different
       // (we could broadcast the tensor to maximum values of dynamic shapes - to be done)
       // case of known shapes
+      // if shapes are known find the output shape from broadcasting
       if (dynamicInputs == 0) {
          auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB);
          fBroadcastFlag = ret.first;
          fShapeY = ret.second;
-         bool broadcast = ret.first > 0;
-         if (broadcast) {
-            // Y is the common shape of A and B
-            bool broadcastA = ret.first & 2;
-            bool broadcastB = ret.first & 1;
-            // Broadcast A to Y
-            if (broadcastA) {
-               fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
-               if (model.IsConstantTensor(fNA)) {
+         if (model.IsConstantTensor(fNA) && model.IsConstantTensor(fNB)) {
+            bool broadcast = fBroadcastFlag > 0;
+            if (broadcast) {
+               // Y is the common shape of A and B
+               bool broadcastA = fBroadcastFlag & 2;
+               bool broadcastB = fBroadcastFlag & 1;
+               // Broadcast A to Y
+               if (broadcastA) {
+                  fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
                   auto data = model.GetInitializedTensorData(fNA);
                   std::shared_ptr<void> broadcastedData(
                      UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeA, fShapeY),
                      std::default_delete<T[]>());
+                  if (model.Verbose())
+                     std::cout << "broadcasted data A " << ConvertShapeToString(fShapeY) << " : "
+                               << ConvertValuesToString(ConvertShapeToLength(fShapeY),
+                                                        static_cast<T *>(broadcastedData.get()))
+                               << std::endl;
                   // Update the data and the shape of A
                   model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
                   fShapeA = fShapeY;
                   fDimShapeA = ConvertShapeToDim(fShapeA);
                }
-            }
-            // Broadcast B to Y
-            if (broadcastB) {
-               fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
-               if (model.IsConstantTensor(fNB)) {
+               // Broadcast B to Y
+               if (broadcastB) {
+                  fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
                   auto data = model.GetInitializedTensorData(fNB);
                   if (model.Verbose())
                      std::cout << "data B " << ConvertShapeToString(fShapeB) << " : "
@@ -174,12 +178,11 @@ public:
                   fShapeB = fShapeY;
                   fDimShapeB = ConvertShapeToDim(fShapeB);
                }
+            } else {
+               fShapeY = fShapeA;
             }
-         } else {
-            fShapeY = fShapeA;
-         }
-         // check case of constant output (if all inputs are defined)
-         if (model.IsConstantTensor(fNA) && model.IsConstantTensor(fNB)) {
+            // tensors are constant: perform here the binary operation
+
             const std::string &nameA = fNBroadcastedA.empty() ? fNA : fNBroadcastedA;
             const std::string &nameB = fNBroadcastedB.empty() ? fNB : fNBroadcastedB;
             auto dataA = static_cast<T *>(model.GetInitializedTensorData(nameA).get());
@@ -189,7 +192,7 @@ public:
                dataY[i] = BinaryOperatorTrait<T, Op>::Func(dataA[i], dataB[i]);
             }
             model.AddConstantTensor<T>(fNY, fShapeY, dataY.data());
-            // flag tensors to not be written in a fil
+            // flag tensors to not be written in the weight file
             model.SetNotWritableInitializedTensor(nameA);
             model.SetNotWritableInitializedTensor(nameB);
             fIsOutputConstant = true;
@@ -199,17 +202,17 @@ public:
                          << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl;
             }
          } else {
+            // case of defined and non-constant tensors
             model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
             if (model.Verbose()) {
                std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
                          << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " "
                          << ConvertShapeToString(fShapeY) << std::endl;
             }
+            // we convert non-dim shapes to Dim shapes
+            fDimShapeY = ConvertShapeToDim(fShapeY);
          }
-         // we convert non-dim shapes to Dim shapes
-         fDimShapeY = ConvertShapeToDim(fShapeY);
-      } // endif of non-parametric shapes
-      else {
+      } else {
          // case A or B have dynamic shapes: we need to broadcast if the shapes are not the same
          auto ret = UTILITY::MultidirectionalBroadcastShape(fDimShapeA, fDimShapeB);
          fBroadcastFlag = ret.first;
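The restructured Initialize code above reads the first element of the pair returned by UTILITY::MultidirectionalBroadcastShape as a bit mask: bit 2 set means A must be broadcast to the common output shape, bit 1 set means B must. A minimal sketch of that contract under NumPy-style multidirectional broadcasting rules (the function name, signature, and error handling below are illustrative assumptions, not SOFIE's actual implementation):

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Illustrative sketch, not SOFIE's implementation: returns {flag, output shape},
// where flag bit 2 means A needs broadcasting and bit 1 means B does.
std::pair<int, std::vector<std::size_t>>
MultidirectionalBroadcastShapeSketch(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
   // Left-pad the shorter shape with 1s so both shapes have the same rank
   std::size_t rank = std::max(a.size(), b.size());
   a.insert(a.begin(), rank - a.size(), 1);
   b.insert(b.begin(), rank - b.size(), 1);
   std::vector<std::size_t> out(rank);
   int flag = 0;
   for (std::size_t i = 0; i < rank; ++i) {
      // Two dimensions are compatible if equal or if one of them is 1
      if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
         throw std::runtime_error("input shapes are not broadcastable");
      out[i] = std::max(a[i], b[i]);
      if (a[i] != out[i]) flag |= 2; // A must be broadcast to the output shape
      if (b[i] != out[i]) flag |= 1; // B must be broadcast to the output shape
   }
   return {flag, out};
}
```

For example, shapes {3, 1} and {1, 4} would give flag 3 (both inputs broadcast) and output shape {3, 4}, matching the broadcastA/broadcastB checks in the diff.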
@@ -274,7 +277,8 @@ public:
          throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
       }
       std::stringstream out;
-      out << SP << "\n//------ " << BinaryOperatorTrait<T, Op>::Name() << "\n";
+      out << SP << "\n//------ " << opName << " " << BinaryOperatorTrait<T, Op>::Name() << " --> "
+          << ConvertDimShapeToString(fDimShapeY) << "\n";
       auto length = ConvertDimShapeToLength(fDimShapeY);
       std::string typeName = TensorType<T>::Name();
 
@@ -323,7 +327,7 @@ public:
             if (fShapeA[i] == 1)
                continue;
             compute_idx_A +=
-               "idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeA.size())) + " * " + stridesA[i] + " +";
+               "idx_" + std::to_string(i + (fShapeY.size() - fShapeA.size())) + " * " + stridesA[i] + " +";
          }
          compute_idx_A.pop_back();
       }
@@ -334,15 +338,15 @@ public:
             if (fShapeB[i] == 1)
                continue;
             compute_idx_B +=
-               "idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeB.size())) + " * " + stridesB[i] + " +";
+               "idx_" + std::to_string(i + (fShapeY.size() - fShapeB.size())) + " * " + stridesB[i] + " +";
          }
          compute_idx_B.pop_back();
       }
       for (size_t i = 0; i < fShapeY.size(); ++i) {
          if (fShapeY[i] != 1) {
-            out << std::string(i + 1, ' ') << "for(size_t idx_" << fNY << i << "=0; idx_" << fNY << i << "<"
-                << fShapeY[i] << "; ++idx_" << fNY << i << "){\n";
-            compute_idx_Y += "idx_" + fNY + std::to_string(i) + "*" + stridesY[i] + "+";
+            out << std::string(i + 1, ' ') << "for(size_t idx_" << i << "=0; idx_" << i << "<" << fShapeY[i]
+                << "; ++idx_" << i << "){\n";
+            compute_idx_Y += "idx_" + std::to_string(i) + "*" + stridesY[i] + "+";
          }
       }
       compute_idx_Y.pop_back();
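The Generate changes above rename the emitted loop indices from idx_<output name><i> to plain idx_<i>, and the input index expressions skip dimensions of size 1. A hand-written, runnable approximation of the kind of code the operator would now emit for an element-wise Add of A with shape {3, 1} and B with shape {1, 4} (the tensor names and literal strides are assumptions for the example, not actual generated output):

```cpp
#include <cstddef>
#include <iostream>

int main()
{
   // Hypothetical tensors: A has shape {3, 1}, B has shape {1, 4},
   // broadcast output Y has shape {3, 4} with strides {4, 1}.
   float tensor_A[3] = {1.f, 2.f, 3.f};
   float tensor_B[4] = {10.f, 20.f, 30.f, 40.f};
   float tensor_Y[12];
   // Dimensions of size 1 are dropped from the input index expressions,
   // mirroring the compute_idx_A / compute_idx_B logic in the diff.
   for (size_t idx_0 = 0; idx_0 < 3; ++idx_0) {
      for (size_t idx_1 = 0; idx_1 < 4; ++idx_1) {
         tensor_Y[idx_0 * 4 + idx_1 * 1] = tensor_A[idx_0 * 1] + tensor_B[idx_1 * 1];
      }
   }
   std::cout << tensor_Y[5] << std::endl; // element (1,1): 2 + 20 = 22
}
```

Keeping the indices free of the output tensor name makes the generated loops shorter and independent of how the output tensor is named in the model.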