Skip to content

Commit 524e71d

Browse files
committed
[tmva][sofie] Improve handling of dynamic parameters in Session constructor
Store the list of Dim parameter names in a vector to keep the order used to generate the code and have an order list of parameters in the Session constructor
1 parent 8a5d53c commit 524e71d

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

tmva/sofie/inc/TMVA/RModel.hxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ private:
2525
std::unordered_map<std::string, InitializedTensor> fInitializedTensors;
2626
std::unordered_map<std::string, TensorInfo> fIntermediateTensorInfos;
2727
std::unordered_map<std::string, DynamicTensorInfo> fDynamicTensorInfos;
28-
std::unordered_map<std::string, std::string>
29-
fShapeParams; // parameters defining the dynamic shape (e.g. batch size), store also its default value
28+
std::unordered_map<std::string, std::string> fShapeParams; // parameters defining the dynamic shape (e.g. batch size), store also its default value
29+
std::vector<std::string> fDimShapeNames; // parameter names used to define the shapes
3030
std::vector<std::string> fOutputTensorNames;
3131
std::vector<std::string> fInputTensorNames; // input tensor names using ONNX order
3232

tmva/sofie/src/RModel.cxx

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ void RModel::AddDynamicTensor(std::string tensor_name, ETensorType type, std::ve
244244
// register it
245245
if (d.dim != size_t(-1)) {
246246
fShapeParams[d.param] = std::to_string(d.dim);
247+
// add also in teh vector list (used to keep the order)
248+
fDimShapeNames.push_back(d.param);
247249
}
248250
}
249251
}
@@ -446,8 +448,12 @@ void RModel::Initialize(const std::map<std::string, size_t> & inputParams, bool
446448
else {
447449
// store the found parametric shape parameters
448450
for (auto &d : input.second.shape) {
449-
if (d.isParam)
450-
fShapeParams[d.param] = std::to_string(d.dim);
451+
if (d.isParam) {
452+
if (fShapeParams.count(d.param) == 0) {
453+
fDimShapeNames.push_back(d.param);
454+
fShapeParams[d.param] = std::to_string(d.dim);
455+
}
456+
}
451457
}
452458
}
453459
}
@@ -864,10 +870,10 @@ void RModel::GenerateSessionCode()
864870
}
865871
// add initialization of shape parameters
866872
// assume all parameters are of type size_t
867-
if (!fShapeParams.empty()) {
868-
for (auto &p : fShapeParams) {
873+
if (!fDimShapeNames.empty()) {
874+
for (auto &p : fDimShapeNames) {
869875
fGC += ",\n";
870-
fGC += " size_t " + p.first + " = " + p.second;
876+
fGC += " size_t " + p + " = " + fShapeParams[p];
871877
}
872878
}
873879
fGC += ") {\n";

0 commit comments

Comments
 (0)