@@ -262,24 +262,37 @@ void RModel::Initialize(int batchSize, bool verbose) {
262262 // loop on inputs and see if shape can be full specified
263263 // if the batch size is provided it can be used to specify the full shape
264264 // Add the full specified tensors in fReadyInputTensors collection
265- for (auto &input : fInputTensorInfos ) {
265+ auto originalInputTensorInfos = fInputTensorInfos ; // need to copy because we may delete elements
266+ for (auto &input : originalInputTensorInfos) {
267+ if (verbose) std::cout << " looking at the tensor " << input.first << std::endl;
266268 // if a batch size is provided convert batch size
267269 // assume is parametrized as "bs" or "batch_size"
268270 if (batchSize > 0 ) {
269271 // std::vector<Dim> shape;
270272 // shape.reserve(input.second.shape.size());
271- for (auto &d : input.second .shape ) {
272- if (d.isParam && (d.param == " bs" || d.param == " batch_size" )) {
273- d = Dim{static_cast <size_t >(batchSize)};
273+ // assume first parameter is teh batch size
274+ if (!input.second .shape .empty ()) {
275+ auto & d0 = input.second .shape [0 ];
276+ if (d0.isParam ) {
277+ if (verbose) std::cout << " Fix the batch size to " << batchSize << std::endl;
278+ d0 = Dim{static_cast <size_t >(batchSize)};
279+ }
280+ else { // look for cases that a bs or bath_size is specified in tensor shape
281+ for (auto &d : input.second .shape ) {
282+ if (d.isParam && (d.param == " bs" || d.param == " batch_size" )) {
283+ d = Dim{static_cast <size_t >(batchSize)};
284+ if (verbose) std::cout << " Input shape has bs or batch_size as names. Fix the batch size to " << batchSize << std::endl;
285+ }
286+ }
274287 }
275288 }
276289 }
277290 auto shape = ConvertShapeToInt (input.second .shape );
278291 if (!shape.empty ()) {
279- // add to the ready input tensor informations
280- AddInputTensorInfo (input.first , input.second .type , shape);
281- // remove from the tensor info
292+ // remove from the tensor info old dynamic shape
282293 fInputTensorInfos .erase (input.first );
294+ // add to the ready input tensor information the new fixed shape
295+ AddInputTensorInfo (input.first , input.second .type , shape);
283296 }
284297 // store the parameters of the input tensors
285298 else {
@@ -633,7 +646,7 @@ void RModel::ReadInitializedTensorsFromFile(long pos) {
633646 fGC += " std::ifstream f;\n " ;
634647 fGC += " f.open(filename);\n " ;
635648 fGC += " if (!f.is_open()) {\n " ;
636- fGC += " throw std::runtime_error(\" tmva-sofie failed to open file for input weights\" );\n " ;
649+ fGC += " throw std::runtime_error(\" tmva-sofie failed to open file \" + filename + \" for input weights\" );\n " ;
637650 fGC += " }\n " ;
638651
639652 if (fIsGNNComponent ) {
@@ -769,7 +782,7 @@ long RModel::WriteInitializedTensorsToFile(std::string filename) {
769782 }
770783 if (!f.is_open ())
771784 throw
772- std::runtime_error (" tmva-sofie failed to open file for tensor weight data" );
785+ std::runtime_error (" tmva-sofie failed to open file " + filename + " for tensor weight data" );
773786 for (auto & i: fInitializedTensors ) {
774787 if (i.second .type () == ETensorType::FLOAT) {
775788 size_t length = 1 ;
0 commit comments