Skip to content

Commit ba52dca

Browse files
committed
fix a bug when specifying batch size in RModel with input param bs
1 parent 2ae6678 commit ba52dca

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

tmva/sofie/src/RModel.cxx

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,23 +263,37 @@ void RModel::Initialize(int batchSize, bool verbose) {
263263
// if the batch size is provided it can be used to specify the full shape
264264
// Add the full specified tensors in fReadyInputTensors collection
265265
for (auto &input : fInputTensorInfos) {
266+
std::cout << "looking at the tensor " << input.first << std::endl;
266267
// if a batch size is provided convert batch size
267268
// assume is parametrized as "bs" or "batch_size"
268269
if (batchSize > 0) {
269270
// std::vector<Dim> shape;
270271
// shape.reserve(input.second.shape.size());
271-
for (auto &d : input.second.shape) {
272-
if (d.isParam && (d.param == "bs" || d.param == "batch_size")) {
273-
d = Dim{static_cast<size_t>(batchSize)};
272+
// assume first parameter is teh batch size
273+
if (!input.second.shape.empty()) {
274+
auto & d0 = input.second.shape[0];
275+
if (d0.isParam) {
276+
if (verbose) std::cout << "Fix the batch size to " << batchSize << std::endl;
277+
d0 = Dim{static_cast<size_t>(batchSize)};
278+
}
279+
else { // look for cases that a bs or bath_size is specified in tensor shape
280+
for (auto &d : input.second.shape) {
281+
if (d.isParam && (d.param == "bs" || d.param == "batch_size")) {
282+
d = Dim{static_cast<size_t>(batchSize)};
283+
if (verbose) std::cout << "Input shape has bs or batch_size as names. Fix the batch size to " << batchSize << std::endl;
284+
}
285+
}
274286
}
275287
}
276288
}
277289
auto shape = ConvertShapeToInt(input.second.shape);
278290
if (!shape.empty()) {
279-
// add to the ready input tensor informations
280-
AddInputTensorInfo(input.first, input.second.type, shape);
281-
// remove from the tensor info
291+
#if 0
292+
// remove from the tensor info old dynamic shape
282293
fInputTensorInfos.erase(input.first);
294+
// add to the ready input tensor information the new fixed shape
295+
AddInputTensorInfo(input.first, input.second.type, shape);
296+
#endif
283297
}
284298
// store the parameters of the input tensors
285299
else {

0 commit comments

Comments
 (0)