@@ -791,12 +791,15 @@ std::unique_ptr<ROperator> MakeKerasIdentity(PyObject* fLayer)
791791// / For adding the Output Tensor infos, only the names of the model's output
792792// / tensors are extracted and are then passed into `AddOutputTensorNameList()`.
793793// /
794+ // / Provide optionally a batch size that can be used to overwrite the one given by the
795+ // / model. If a batch size is not given 1 is used if the model does not provide a batch size
796+ // /
794797// / Example Usage:
795798// / ~~~ {.cpp}
796799// / using TMVA::Experimental::SOFIE;
797800// / RModel model = PyKeras::Parse("trained_model_dense.h5");
798801// / ~~~
799- RModel Parse (std::string filename){
802+ RModel Parse (std::string filename, int batch_size ){
800803
801804 char sep = ' /' ;
802805 #ifdef _WIN32
@@ -966,8 +969,11 @@ RModel Parse(std::string filename){
966969 // Getting the shape vector from the Tuple object
967970 std::vector<size_t >fInputShape = GetDataFromTuple (fPInputShapes );
968971 if (static_cast <int >(fInputShape [0 ]) <= 0 ){
969- fInputShape [0 ] = 1 ;
970- std::cout << " Model has not a defined batch size, assume is 1 - input shape : "
972+ fInputShape [0 ] = std::max (batch_size,1 );
973+ std::cout << " Model has not a defined batch size " ;
974+ if (batch_size <=0 ) std::cout << " assume is 1 " ;
975+ else std::cout << " use given value of " << batch_size;
976+ std::cout << " - input shape for tensor " << fInputName << " : "
971977 << TMVA::Experimental::SOFIE::ConvertShapeToString (fInputShape ) << std::endl;
972978 }
973979 rmodel.AddInputTensorInfo (fInputName , ETensorType::FLOAT, fInputShape );
@@ -995,8 +1001,11 @@ RModel Parse(std::string filename){
9951001
9961002 std::vector<size_t >fInputShape = GetDataFromTuple (fInputShapeTuple );
9971003 if (static_cast <int >(fInputShape [0 ]) <= 0 ){
998- fInputShape [0 ] = 1 ;
999- std::cout << " Model has not a defined batch size, assume is 1 - input shape for tensor "
1004+ fInputShape [0 ] = std::max (batch_size,1 );
1005+ std::cout << " Model has not a defined batch size " ;
1006+ if (batch_size <=0 ) std::cout << " assume is 1 " ;
1007+ else std::cout << " use given value of " << batch_size;
1008+ std::cout << " - input shape for tensor "
10001009 << fInputName << " : " << TMVA::Experimental::SOFIE::ConvertShapeToString (fInputShape ) << std::endl;
10011010 }
10021011 rmodel.AddInputTensorInfo (fInputName , ETensorType::FLOAT, fInputShape );
0 commit comments