Commit fd48a6d
[tmva][sofie] Fix initialisation of BasicBinary operator
In the initialization of the BasicBinary operator, do not change the input shapes of the tensors; just compute the output shape. Instead, broadcast the tensors and apply the operator at the initialization phase when the input tensors are constant.
1 parent bdee40f commit fd48a6d
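
For reference, the multidirectional (ONNX/NumPy-style) broadcasting computed by
UTILITY::MultidirectionalBroadcastShape can be sketched as below. This is an
illustrative stand-in, not SOFIE's implementation; the flag convention (bit 2 set
when A needs broadcasting, bit 1 when B does) is the one the new code reads back
via fBroadcastFlag & 2 and fBroadcastFlag & 1 in the diff.

   #include <cstddef>
   #include <stdexcept>
   #include <utility>
   #include <vector>

   // Sketch: broadcast two shapes to a common output shape.
   // Returns {flag, outputShape}; flag & 2 => A must be broadcast,
   // flag & 1 => B must be broadcast.
   std::pair<int, std::vector<size_t>> BroadcastShapes(std::vector<size_t> a, std::vector<size_t> b)
   {
      // Left-pad the shorter shape with 1s so both have the same rank.
      while (a.size() < b.size())
         a.insert(a.begin(), 1);
      while (b.size() < a.size())
         b.insert(b.begin(), 1);
      std::vector<size_t> out(a.size());
      int flag = 0;
      for (size_t i = 0; i < a.size(); ++i) {
         if (a[i] == b[i]) {
            out[i] = a[i];
         } else if (a[i] == 1) {
            out[i] = b[i];
            flag |= 2; // A is stretched along this axis
         } else if (b[i] == 1) {
            out[i] = a[i];
            flag |= 1; // B is stretched along this axis
         } else {
            throw std::runtime_error("shapes are not broadcastable");
         }
      }
      return {flag, out};
   }

For example, BroadcastShapes({2, 3}, {3}) yields flag 1 and shape {2, 3}: only B
has to be broadcast, so only B would be materialised to the output shape below.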

File tree

1 file changed (+34, -30 lines)

tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 34 additions & 30 deletions
@@ -129,33 +129,37 @@ public:
       // check if need to broadcast at initialization time if shapes are known and different
       // (we could broadcast the tensor tensor to maximum values of dynamic shapes - to be done)
       // case of known shapes
+      // if shapes are known find the output shape from broadcasting
       if (dynamicInputs == 0) {
          auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB);
          fBroadcastFlag = ret.first;
          fShapeY = ret.second;
-         bool broadcast = ret.first > 0;
-         if (broadcast) {
-            // Y is the common shape of A and B
-            bool broadcastA = ret.first & 2;
-            bool broadcastB = ret.first & 1;
-            // Broadcast A to Y
-            if (broadcastA) {
-               fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
-               if (model.IsConstantTensor(fNA)) {
+         if (model.IsConstantTensor(fNA) && model.IsConstantTensor(fNB)) {
+            bool broadcast = fBroadcastFlag > 0;
+            if (broadcast) {
+               // Y is the common shape of A and B
+               bool broadcastA = fBroadcastFlag & 2;
+               bool broadcastB = fBroadcastFlag & 1;
+               // Broadcast A to Y
+               if (broadcastA) {
+                  fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
                   auto data = model.GetInitializedTensorData(fNA);
                   std::shared_ptr<void> broadcastedData(
                      UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeA, fShapeY),
                      std::default_delete<T[]>());
+                  if (model.Verbose())
+                     std::cout << "broadcasted data A " << ConvertShapeToString(fShapeY) << " : "
+                               << ConvertValuesToString(ConvertShapeToLength(fShapeY),
+                                                        static_cast<T *>(broadcastedData.get()))
+                               << std::endl;
                   // Update the data and the shape of A
                   model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
                   fShapeA = fShapeY;
                   fDimShapeA = ConvertShapeToDim(fShapeA);
                }
-            }
-            // Broadcast B to Y
-            if (broadcastB) {
-               fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
-               if (model.IsConstantTensor(fNB)) {
+               // Broadcast B to Y
+               if (broadcastB) {
+                  fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
                   auto data = model.GetInitializedTensorData(fNB);
                   if (model.Verbose())
                      std::cout << "data B " << ConvertShapeToString(fShapeB) << " : "
@@ -174,12 +178,11 @@ public:
                   fShapeB = fShapeY;
                   fDimShapeB = ConvertShapeToDim(fShapeB);
                }
+            } else {
+               fShapeY = fShapeA;
             }
-         } else {
-            fShapeY = fShapeA;
-         }
-         // check case of constant output (if all inputs are defined)
-         if (model.IsConstantTensor(fNA) && model.IsConstantTensor(fNB)) {
+            // tensors are constant: perform here the binary operation
+
             const std::string &nameA = fNBroadcastedA.empty() ? fNA : fNBroadcastedA;
             const std::string &nameB = fNBroadcastedB.empty() ? fNB : fNBroadcastedB;
             auto dataA = static_cast<T *>(model.GetInitializedTensorData(nameA).get());
@@ -189,7 +192,7 @@ public:
                dataY[i] = BinaryOperatorTrait<T, Op>::Func(dataA[i], dataB[i]);
             }
             model.AddConstantTensor<T>(fNY, fShapeY, dataY.data());
-            // flag tensors to not be written in a fil
+            // flag tensors to not be written in the weight file
             model.SetNotWritableInitializedTensor(nameA);
             model.SetNotWritableInitializedTensor(nameB);
             fIsOutputConstant = true;
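
Because both inputs are constant here, the element-wise result can be computed
once at initialisation and registered as a constant tensor (constant folding),
so nothing has to be generated for this operator at inference time. A toy
analogue of the fold loop, assuming an Add-like trait with the Func signature
seen above (AddTrait and ConstantFold are hypothetical names, not SOFIE's):

   #include <cstddef>
   #include <vector>

   // Toy analogue of BinaryOperatorTrait<T, Op>::Func for Op = Add.
   template <typename T>
   struct AddTrait {
      static T Func(T a, T b) { return a + b; }
   };

   // Fold two already-broadcast constant buffers of equal length,
   // as the loop above does once at initialization time.
   template <typename T, typename Trait>
   std::vector<T> ConstantFold(const std::vector<T> &a, const std::vector<T> &b)
   {
      std::vector<T> y(a.size());
      for (std::size_t i = 0; i < a.size(); ++i)
         y[i] = Trait::Func(a[i], b[i]);
      return y;
   }
   // e.g. ConstantFold<float, AddTrait<float>>({1, 2}, {10, 20}) -> {11, 22}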
@@ -199,17 +202,17 @@ public:
                          << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl;
             }
          } else {
+            // case of defined and non-constant tensors
             model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
             if (model.Verbose()) {
                std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
                          << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " "
                          << ConvertShapeToString(fShapeY) << std::endl;
             }
+            // we convert non-dim shapes to Dim shapes
+            fDimShapeY = ConvertShapeToDim(fShapeY);
          }
-         // we convert non-dim shapes to Dim shapes
-         fDimShapeY = ConvertShapeToDim(fShapeY);
-      } // endif of non-parametric shapes
-      else {
+      } else {
         // case A or B have dynamic shapes. We need to broadcast if shape are not same
         auto ret = UTILITY::MultidirectionalBroadcastShape(fDimShapeA, fDimShapeB);
         fBroadcastFlag = ret.first;
@@ -274,7 +277,8 @@ public:
         throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
      }
      std::stringstream out;
-     out << SP << "\n//------ " << BinaryOperatorTrait<T, Op>::Name() << "\n";
+     out << SP << "\n//------ " << opName << " " << BinaryOperatorTrait<T, Op>::Name() << " --> "
+         << ConvertDimShapeToString(fDimShapeY) << "\n";
      auto length = ConvertDimShapeToLength(fDimShapeY);
      std::string typeName = TensorType<T>::Name();

@@ -323,7 +327,7 @@ public:
            if (fShapeA[i] == 1)
               continue;
            compute_idx_A +=
-               " idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeA.size())) + " * " + stridesA[i] + " +";
+               " idx_" + std::to_string(i + (fShapeY.size() - fShapeA.size())) + " * " + stridesA[i] + " +";
         }
         compute_idx_A.pop_back();
      }
@@ -334,15 +338,15 @@ public:
            if (fShapeB[i] == 1)
               continue;
            compute_idx_B +=
-               " idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeB.size())) + " * " + stridesB[i] + " +";
+               " idx_" + std::to_string(i + (fShapeY.size() - fShapeB.size())) + " * " + stridesB[i] + " +";
         }
         compute_idx_B.pop_back();
      }
      for (size_t i = 0; i < fShapeY.size(); ++i) {
         if (fShapeY[i] != 1) {
-            out << std::string(i + 1, ' ') << "for(size_t idx_" << fNY << i << "=0; idx_" << fNY << i << "<"
-                << fShapeY[i] << "; ++idx_" << fNY << i << "){\n";
-            compute_idx_Y += "idx_" + fNY + std::to_string(i) + "*" + stridesY[i] + "+";
+            out << std::string(i + 1, ' ') << "for(size_t idx_" << i << "=0; idx_" << i << "<" << fShapeY[i]
+                << "; ++idx_" << i << "){\n";
+            compute_idx_Y += "idx_" + std::to_string(i) + "*" + stridesY[i] + "+";
         }
      }
      compute_idx_Y.pop_back();