Skip to content

Commit 6510b5e

Browse files
authored
Merge pull request #549 from rest-for-physics/aezq_dsgmImprovements
TRestDataSet and TRestDataSetGainMap improvements
2 parents 19e9530 + 2f5641a commit 6510b5e

File tree

5 files changed

+57
-50
lines changed

5 files changed

+57
-50
lines changed

source/framework/analysis/src/TRestDataSetGainMap.cxx

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ void TRestDataSetGainMap::GenerateGainMap() {
213213
/// \brief Function to calibrate a dataset with this gain map.
214214
///
215215
/// \param dataSetFileName the name of the root file where the TRestDataSet to be calibrated is stored.
216+
/// If the file is not a TRestDataSet, it will be treated as a file pattern for several TRestRun files
217+
/// to generate a temporary TRestDataSet with the needed observables.
216218
/// \param outputFileName the name of the output (root) file where the calibrated TRestDataSet will be
217219
/// exported. If empty, the output file will be named as the input file plus the name of the
218220
/// TRestDataSetGainMap. E.g. "data/myDataSet.root" -> "data/myDataSet_<gmName>.root".
@@ -230,7 +232,17 @@ void TRestDataSetGainMap::CalibrateDataSet(const std::string& dataSetFileName, s
230232

231233
TRestDataSet dataSet;
232234
dataSet.EnableMultiThreading(true);
233-
dataSet.Import(dataSetFileName);
235+
236+
if (TRestTools::isDataSet(dataSetFileName)) {
237+
dataSet.Import(dataSetFileName);
238+
} else {
239+
RESTWarning << dataSetFileName << " is not a dataset. Generating a temporal one..." << RESTendl;
240+
// generate the dataset with the needed observables
241+
dataSet.SetFilePattern(dataSetFileName);
242+
dataSet.SetObservablesList({"*"}); // get all observables
243+
dataSet.GenerateDataSet();
244+
}
245+
234246
auto dataFrame = dataSet.GetDataFrame();
235247

236248
// Define a new column with the identifier (pmID) of the module for each row (event)
@@ -294,17 +306,9 @@ void TRestDataSetGainMap::CalibrateDataSet(const std::string& dataSetFileName, s
294306
}
295307

296308
// Export dataset. Exclude columns if requested.
297-
std::set<std::string> excludeCol; // vector with the explicit column names to be excluded
298309
auto columns = dataSet.GetDataFrame().GetColumnNames();
299-
// Get the columns to be excluded from the list of columns. It accepts wildcards "*" and "?"
300-
for (auto& eC : excludeColumns) {
301-
if (eC.find("*") != std::string::npos || eC.find("?") != std::string::npos) {
302-
for (auto& c : columns)
303-
if (MatchString(c, eC)) excludeCol.insert(c);
304-
} else if (std::find(columns.begin(), columns.end(), eC) != columns.end())
305-
excludeCol.insert(eC);
306-
}
307-
// Remove the calibObsName, calibObsNameFullSpc and pmIDname from the list of columns to be excluded
310+
std::set<std::string> excludeCol = TRestTools::GetMatchingStrings(columns, excludeColumns);
311+
// Never exclude the calibObsName, calibObsNameFullSpc and pmIDname
308312
excludeCol.erase(calibObsName);
309313
excludeCol.erase(calibObsNameFullSpc);
310314
excludeCol.erase(pmIDname);

source/framework/core/inc/TRestDataSet.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,6 @@ class TRestDataSet : public TRestMetadata {
6060
/// It contains a list of the observables that will be added to the final tree or exported file
6161
std::vector<std::string> fObservablesList; //<
6262

63-
/// It contains a list of the process where all observables should be added
64-
std::vector<std::string> fProcessObservablesList; //<
65-
6663
/// A list of metadata members where filters will be applied
6764
std::vector<std::string> fFilterMetadata; //<
6865

@@ -172,7 +169,6 @@ class TRestDataSet : public TRestMetadata {
172169
inline auto GetFilePattern() const { return fFilePattern; }
173170
inline auto GetObservablesList() const { return fObservablesList; }
174171
inline auto GetFileSelection() const { return fFileSelection; }
175-
inline auto GetProcessObservablesList() const { return fProcessObservablesList; }
176172
inline auto GetFilterMetadata() const { return fFilterMetadata; }
177173
inline auto GetFilterContains() const { return fFilterContains; }
178174
inline auto GetFilterGreaterThan() const { return fFilterGreaterThan; }
@@ -215,6 +211,6 @@ class TRestDataSet : public TRestMetadata {
215211
TRestDataSet(const char* cfgFileName, const std::string& name = "");
216212
~TRestDataSet();
217213

218-
ClassDefOverride(TRestDataSet, 8);
214+
ClassDefOverride(TRestDataSet, 9);
219215
};
220216
#endif

source/framework/core/src/TRestDataSet.cxx

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -359,29 +359,14 @@ void TRestDataSet::GenerateDataSet() {
359359

360360
///// Disentangling process observables --> producing finalList
361361
TRestRun run(fFileSelection.front());
362-
std::vector<std::string> finalList;
363-
finalList.push_back("runOrigin");
364-
finalList.push_back("eventID");
365-
finalList.push_back("timeStamp");
362+
std::set<std::string> finalList;
363+
finalList.insert("runOrigin");
364+
finalList.insert("eventID");
365+
finalList.insert("timeStamp");
366366

367367
auto obsNames = run.GetAnalysisTree()->GetObservableNames();
368-
for (const auto& obs : fObservablesList) {
369-
if (std::find(obsNames.begin(), obsNames.end(), obs) != obsNames.end()) {
370-
finalList.push_back(obs);
371-
} else {
372-
RESTWarning << " Observable " << obs << " not found in observable list, skipping..." << RESTendl;
373-
}
374-
}
375-
376-
for (const auto& name : obsNames) {
377-
for (const auto& pcs : fProcessObservablesList) {
378-
if (name.find(pcs) == 0) finalList.push_back(name);
379-
}
380-
}
381-
382-
// Remove duplicated observables if any
383-
std::sort(finalList.begin(), finalList.end());
384-
finalList.erase(std::unique(finalList.begin(), finalList.end()), finalList.end());
368+
auto obsFromList = TRestTools::GetMatchingStrings(obsNames, fObservablesList);
369+
finalList.insert(obsFromList.begin(), obsFromList.end());
385370

386371
if (fMT)
387372
ROOT::EnableImplicitMT();
@@ -397,11 +382,11 @@ void TRestDataSet::GenerateDataSet() {
397382
// Adding new user columns added to the dataset
398383
for (const auto& [cName, cExpression] : fColumnNameExpressions) {
399384
RESTInfo << "Adding column to dataset: " << cName << RESTendl;
400-
finalList.emplace_back(cName);
385+
finalList.emplace(cName);
401386
fDataFrame = DefineColumn(cName, cExpression);
402387
}
403388

404-
RegenerateTree(finalList);
389+
RegenerateTree(std::vector<std::string>(finalList.begin(), finalList.end()));
405390

406391
RESTInfo << " - Dataset generated!" << RESTendl;
407392
}
@@ -672,21 +657,13 @@ void TRestDataSet::PrintMetadata() {
672657
RESTMetadata << " " << RESTendl;
673658

674659
if (!fObservablesList.empty()) {
675-
RESTMetadata << " Single observables added:" << RESTendl;
660+
RESTMetadata << " Observables added:" << RESTendl;
676661
RESTMetadata << " -------------------------" << RESTendl;
677662
for (const auto& l : fObservablesList) RESTMetadata << " - " << l << RESTendl;
678663

679664
RESTMetadata << " " << RESTendl;
680665
}
681666

682-
if (!fProcessObservablesList.empty()) {
683-
RESTMetadata << " Process observables added: " << RESTendl;
684-
RESTMetadata << " -------------------------- " << RESTendl;
685-
for (const auto& l : fProcessObservablesList) RESTMetadata << " - " << l << RESTendl;
686-
687-
RESTMetadata << " " << RESTendl;
688-
}
689-
690667
if (!fFilterMetadata.empty()) {
691668
RESTMetadata << " Metadata filters: " << RESTendl;
692669
RESTMetadata << " ----------------- " << RESTendl;
@@ -811,7 +788,10 @@ void TRestDataSet::InitFromConfigFile() {
811788

812789
std::vector<std::string> obsList = REST_StringHelper::Split(observables, ",");
813790

814-
for (const auto& l : obsList) fProcessObservablesList.push_back(l);
791+
for (const auto& l : obsList) {
792+
std::string processObsPattern = l + "_*";
793+
fObservablesList.push_back(processObsPattern);
794+
}
815795

816796
obsProcessDefinition = GetNextElement(obsProcessDefinition);
817797
}
@@ -1033,7 +1013,6 @@ TRestDataSet& TRestDataSet::operator=(TRestDataSet& dS) {
10331013
fFilePattern = dS.GetFilePattern();
10341014
fObservablesList = dS.GetObservablesList();
10351015
fFileSelection = dS.GetFileSelection();
1036-
fProcessObservablesList = dS.GetProcessObservablesList();
10371016
fFilterMetadata = dS.GetFilterMetadata();
10381017
fFilterContains = dS.GetFilterContains();
10391018
fFilterGreaterThan = dS.GetFilterGreaterThan();

source/framework/tools/inc/TRestTools.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
#include <map>
3131
#include <memory>
32+
#include <set>
3233
#include <string>
3334
#include <vector>
3435

@@ -85,6 +86,8 @@ class TRestTools {
8586
static std::string GetFileNameRoot(const std::string& fullname);
8687
static std::vector<std::string> GetObservablesInString(const std::string& observablesStr,
8788
bool removeDuplicates = true);
89+
static std::set<std::string> GetMatchingStrings(const std::vector<std::string>& stack,
90+
const std::vector<std::string>& wantedStrings);
8891

8992
static int GetBinaryFileColumns(std::string fname);
9093

source/framework/tools/src/TRestTools.cxx

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,31 @@ std::vector<std::string> TRestTools::GetObservablesInString(const std::string& o
861861
return obsList;
862862
}
863863

864+
////////////////////////////////////////////////////////////
865+
/// \brief Returns a set of strings that match the wanted strings from the stack.
866+
/// The wanted strings can contain wildcards "*" and "?".
867+
/// \param stack: vector of strings to be searched
868+
/// \param wantedStrings: vector of strings with the wanted strings to be matched.
869+
/// \return a set of strings that match the wanted strings
870+
/// e.g.
871+
/// Input: stack = {"x1", "x2", "x11", "y1", "y2", "y11", "z1", "z2"},
872+
/// wantedStrings = {"x*", "y?", "z1"},
873+
/// Output: {"x1", "x11", "x2", "y1", "y2", "z1"}
874+
///
875+
std::set<std::string> TRestTools::GetMatchingStrings(const std::vector<std::string>& stack,
876+
const std::vector<std::string>& wantedStrings) {
877+
std::set<std::string> result;
878+
for (auto& ws : wantedStrings) {
879+
if (ws.find("*") != std::string::npos || ws.find("?") != std::string::npos) {
880+
for (auto& c : stack)
881+
if (MatchString(c, ws)) result.insert(c);
882+
} else if (std::find(stack.begin(), stack.end(), ws) != stack.end())
883+
result.insert(ws);
884+
}
885+
// return std::vector<std::string>(result.begin(), result.end()); // convert to vector
886+
return result;
887+
}
888+
864889
///////////////////////////////////////////////
865890
/// \brief Returns the input string but without multiple slashes ("/")
866891
///

0 commit comments

Comments
 (0)