Skip to content

Commit 1b77790

Browse files
committed
[tmva][sofie] FIx setting batch size in Keras parsing
Fix for setting t he batch size when parsing a Keras model where the input batch size is not defined Add extra optional parameter in the Parse function to specify the batch size
1 parent e0347f7 commit 1b77790

File tree

4 files changed

+25
-8
lines changed

4 files changed

+25
-8
lines changed

tmva/pymva/inc/TMVA/RModelParser_Keras.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ namespace PyKeras{
4545
/// Parser function for translatng Keras .h5 model into a RModel object.
4646
/// Accepts the file location of a Keras model and returns the
4747
/// equivalent RModel object.
48-
RModel Parse(std::string filename);
48+
/// One can specify as option a batch size that can be used when the input Keras model
49+
/// has not a defined input batch size : e.g. for input = (input_dim,)
50+
RModel Parse(std::string filename, int batch_size = -1);
4951

5052
}//PyKeras
5153
}//SOFIE

tmva/pymva/src/RModelParser_Keras.cxx

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -791,12 +791,15 @@ std::unique_ptr<ROperator> MakeKerasIdentity(PyObject* fLayer)
791791
/// For adding the Output Tensor infos, only the names of the model's output
792792
/// tensors are extracted and are then passed into `AddOutputTensorNameList()`.
793793
///
794+
/// Provide optionally a batch size that can be used to overwrite the one given by the
795+
/// model. If a batch size is not given 1 is used if the model does not provide a batch size
796+
///
794797
/// Example Usage:
795798
/// ~~~ {.cpp}
796799
/// using TMVA::Experimental::SOFIE;
797800
/// RModel model = PyKeras::Parse("trained_model_dense.h5");
798801
/// ~~~
799-
RModel Parse(std::string filename){
802+
RModel Parse(std::string filename, int batch_size){
800803

801804
char sep = '/';
802805
#ifdef _WIN32
@@ -966,8 +969,11 @@ RModel Parse(std::string filename){
966969
// Getting the shape vector from the Tuple object
967970
std::vector<size_t>fInputShape = GetDataFromTuple(fPInputShapes);
968971
if (static_cast<int>(fInputShape[0]) <= 0){
969-
fInputShape[0] = 1;
970-
std::cout << "Model has not a defined batch size, assume is 1 - input shape : "
972+
fInputShape[0] = std::max(batch_size,1);
973+
std::cout << "Model has not a defined batch size ";
974+
if (batch_size <=0) std::cout << " assume is 1 ";
975+
else std::cout << " use given value of " << batch_size;
976+
std::cout << " - input shape for tensor " << fInputName << " : "
971977
<< TMVA::Experimental::SOFIE::ConvertShapeToString(fInputShape) << std::endl;
972978
}
973979
rmodel.AddInputTensorInfo(fInputName, ETensorType::FLOAT, fInputShape);
@@ -995,8 +1001,11 @@ RModel Parse(std::string filename){
9951001

9961002
std::vector<size_t>fInputShape = GetDataFromTuple(fInputShapeTuple);
9971003
if (static_cast<int>(fInputShape[0]) <= 0){
998-
fInputShape[0] = 1;
999-
std::cout << "Model has not a defined batch size, assume is 1 - input shape for tensor "
1004+
fInputShape[0] = std::max(batch_size,1);
1005+
std::cout << "Model has not a defined batch size ";
1006+
if (batch_size <=0) std::cout << " assume is 1 ";
1007+
else std::cout << " use given value of " << batch_size;
1008+
std::cout << " - input shape for tensor "
10001009
<< fInputName << " : " << TMVA::Experimental::SOFIE::ConvertShapeToString(fInputShape) << std::endl;
10011010
}
10021011
rmodel.AddInputTensorInfo(fInputName, ETensorType::FLOAT, fInputShape);

tmva/pymva/test/TestRModelParserKeras.C

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ void GenerateModels() {
2424
TEST(RModelParser_Keras, SEQUENTIAL)
2525
{
2626
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
27+
// input is 8 x batch size that is fixed to be 4
2728
std::vector<float> inputSequential = { 0.12107884, 0.89718615, 0.89123899, 0.32197549,
2829
0.17891638, 0.83555135, 0.98680066, 0.14496809,
2930
0.07255503, 0.55386989, 0.6628149 , 0.29843291,
@@ -38,7 +39,7 @@ TEST(RModelParser_Keras, SEQUENTIAL)
3839
if (gSystem->AccessPathName("KerasModelSequential.h5",kFileExists))
3940
GenerateModels();
4041

41-
TMVA::Experimental:: RSofieReader r("KerasModelSequential.h5");
42+
TMVA::Experimental:: RSofieReader r("KerasModelSequential.h5",{{4,8}});
4243
std::vector<float> outputSequential = r.Compute(inputSequential);
4344

4445

tmva/tmva/inc/TMVA/RSofieReader.hxx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,12 @@ public:
113113
if (gSystem->Load("libPyMVA") < 0) {
114114
throw std::runtime_error("RSofieReader: cannot use SOFIE with Keras since libPyMVA is missing");
115115
}
116-
parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path + "\"); \n";
116+
// assume batch size is first entry in first input !
117+
std::string batch_size = "-1";
118+
if (!inputShapes.empty() && ! inputShapes[0].empty())
119+
batch_size = std::to_string(inputShapes[0][0]);
120+
parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path +
121+
"\"," + batch_size + "); \n";
117122
}
118123
else if (type == kPt) {
119124
// use PyTorch direct parser

0 commit comments

Comments
 (0)