Skip to content

Commit e0347f7

Browse files
committed
[tmva][sofie] Apply a fix when batch size is given by user
When input shape is parametric, give possibility to fix the batch size if provided by user in RModel::Generate
1 parent 8a1fb82 commit e0347f7

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

tmva/sofie/src/RModel.cxx

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,24 +262,37 @@ void RModel::Initialize(int batchSize, bool verbose) {
262262
// loop on inputs and see if shape can be full specified
263263
// if the batch size is provided it can be used to specify the full shape
264264
// Add the full specified tensors in fReadyInputTensors collection
265-
for (auto &input : fInputTensorInfos) {
265+
auto originalInputTensorInfos = fInputTensorInfos; // need to copy because we may delete elements
266+
for (auto &input : originalInputTensorInfos) {
267+
if (verbose) std::cout << "looking at the tensor " << input.first << std::endl;
266268
// if a batch size is provided convert batch size
267269
// assume is parametrized as "bs" or "batch_size"
268270
if (batchSize > 0) {
269271
// std::vector<Dim> shape;
270272
// shape.reserve(input.second.shape.size());
271-
for (auto &d : input.second.shape) {
272-
if (d.isParam && (d.param == "bs" || d.param == "batch_size")) {
273-
d = Dim{static_cast<size_t>(batchSize)};
273+
// assume first parameter is teh batch size
274+
if (!input.second.shape.empty()) {
275+
auto & d0 = input.second.shape[0];
276+
if (d0.isParam) {
277+
if (verbose) std::cout << "Fix the batch size to " << batchSize << std::endl;
278+
d0 = Dim{static_cast<size_t>(batchSize)};
279+
}
280+
else { // look for cases that a bs or bath_size is specified in tensor shape
281+
for (auto &d : input.second.shape) {
282+
if (d.isParam && (d.param == "bs" || d.param == "batch_size")) {
283+
d = Dim{static_cast<size_t>(batchSize)};
284+
if (verbose) std::cout << "Input shape has bs or batch_size as names. Fix the batch size to " << batchSize << std::endl;
285+
}
286+
}
274287
}
275288
}
276289
}
277290
auto shape = ConvertShapeToInt(input.second.shape);
278291
if (!shape.empty()) {
279-
// add to the ready input tensor informations
280-
AddInputTensorInfo(input.first, input.second.type, shape);
281-
// remove from the tensor info
292+
// remove from the tensor info old dynamic shape
282293
fInputTensorInfos.erase(input.first);
294+
// add to the ready input tensor information the new fixed shape
295+
AddInputTensorInfo(input.first, input.second.type, shape);
283296
}
284297
// store the parameters of the input tensors
285298
else {
@@ -633,7 +646,7 @@ void RModel::ReadInitializedTensorsFromFile(long pos) {
633646
fGC += " std::ifstream f;\n";
634647
fGC += " f.open(filename);\n";
635648
fGC += " if (!f.is_open()) {\n";
636-
fGC += " throw std::runtime_error(\"tmva-sofie failed to open file for input weights\");\n";
649+
fGC += " throw std::runtime_error(\"tmva-sofie failed to open file \" + filename + \" for input weights\");\n";
637650
fGC += " }\n";
638651

639652
if(fIsGNNComponent) {
@@ -769,7 +782,7 @@ long RModel::WriteInitializedTensorsToFile(std::string filename) {
769782
}
770783
if (!f.is_open())
771784
throw
772-
std::runtime_error("tmva-sofie failed to open file for tensor weight data");
785+
std::runtime_error("tmva-sofie failed to open file " + filename + " for tensor weight data");
773786
for (auto& i: fInitializedTensors) {
774787
if (i.second.type() == ETensorType::FLOAT) {
775788
size_t length = 1;

0 commit comments

Comments
 (0)