Skip to content

Commit 9591b5a

Browse files
authored
Merge pull request #524 from rest-for-physics/jgalan_dataset_updates
Component and dataset upgrades
2 parents 8b97d34 + fe6aad7 commit 9591b5a

File tree

4 files changed

+100
-34
lines changed

4 files changed

+100
-34
lines changed

source/framework/core/inc/TRestDataSet.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class TRestDataSet : public TRestMetadata {
112112
Bool_t fExternal = false; //<
113113

114114
/// The resulting RDF::RNode object after initialization
115-
ROOT::RDF::RNode fDataSet = ROOT::RDataFrame(0); //!
115+
ROOT::RDF::RNode fDataFrame = ROOT::RDataFrame(0); //!
116116

117117
/// A pointer to the generated tree
118118
TChain* fTree = nullptr; //!
@@ -122,12 +122,14 @@ class TRestDataSet : public TRestMetadata {
122122
protected:
123123
virtual std::vector<std::string> FileSelection();
124124

125+
void RegenerateTree(std::vector<std::string> finalList = {});
126+
125127
public:
126128
/// Gives access to the RDataFrame
127129
ROOT::RDF::RNode GetDataFrame() const {
128130
if (!fExternal && fTree == nullptr)
129131
RESTWarning << "DataFrame has not been yet initialized" << RESTendl;
130-
return fDataSet;
132+
return fDataFrame;
131133
}
132134

133135
void EnableMultiThreading(Bool_t enable = true) { fMT = enable; }
@@ -152,7 +154,7 @@ class TRestDataSet : public TRestMetadata {
152154
}
153155

154156
/// Number of variables (or observables)
155-
size_t GetNumberOfColumns() { return fDataSet.GetColumnNames().size(); }
157+
size_t GetNumberOfColumns() { return fDataFrame.GetColumnNames().size(); }
156158

157159
/// Number of variables (or observables)
158160
size_t GetNumberOfBranches() { return GetNumberOfColumns(); }
@@ -187,7 +189,7 @@ class TRestDataSet : public TRestMetadata {
187189

188190
void SetTotalTimeInSeconds(Double_t seconds) { fTotalDuration = seconds; }
189191
void SetDataFrame(const ROOT::RDF::RNode& dS) {
190-
fDataSet = dS;
192+
fDataFrame = dS;
191193
fExternal = true;
192194
}
193195

@@ -198,8 +200,12 @@ class TRestDataSet : public TRestMetadata {
198200
void Export(const std::string& filename, std::vector<std::string> excludeColumns = {});
199201

200202
ROOT::RDF::RNode MakeCut(const TRestCut* cut);
203+
ROOT::RDF::RNode ApplyRange(size_t from, size_t to);
204+
ROOT::RDF::RNode Range(size_t from, size_t to);
201205
ROOT::RDF::RNode DefineColumn(const std::string& columnName, const std::string& formula);
202206

207+
size_t GetEntries();
208+
203209
void PrintMetadata() override;
204210
void Initialize() override;
205211

@@ -209,6 +215,6 @@ class TRestDataSet : public TRestMetadata {
209215
TRestDataSet(const char* cfgFileName, const std::string& name = "");
210216
~TRestDataSet();
211217

212-
ClassDefOverride(TRestDataSet, 7);
218+
ClassDefOverride(TRestDataSet, 8);
213219
};
214220
#endif

source/framework/core/src/TRestDataSet.cxx

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -382,30 +382,40 @@ void TRestDataSet::GenerateDataSet() {
382382
ROOT::DisableImplicitMT();
383383

384384
RESTInfo << "Initializing dataset" << RESTendl;
385-
fDataSet = ROOT::RDataFrame("AnalysisTree", fFileSelection);
385+
fDataFrame = ROOT::RDataFrame("AnalysisTree", fFileSelection);
386386

387387
RESTInfo << "Making cuts" << RESTendl;
388-
fDataSet = MakeCut(fCut);
388+
fDataFrame = MakeCut(fCut);
389389

390390
// Adding new user columns added to the dataset
391391
for (const auto& [cName, cExpression] : fColumnNameExpressions) {
392392
RESTInfo << "Adding column to dataset: " << cName << RESTendl;
393393
finalList.emplace_back(cName);
394-
fDataSet = DefineColumn(cName, cExpression);
394+
fDataFrame = DefineColumn(cName, cExpression);
395395
}
396396

397+
RegenerateTree(finalList);
398+
399+
RESTInfo << " - Dataset generated!" << RESTendl;
400+
}
401+
402+
///////////////////////////////////////////////
403+
/// \brief It regenerates the tree so that it is an exact copy of the present DataFrame
404+
///
405+
void TRestDataSet::RegenerateTree(std::vector<std::string> finalList) {
397406
RESTInfo << "Generating snapshot." << RESTendl;
398407
std::string user = getenv("USER");
399408
std::string fOutName = "/tmp/rest_output_" + user + ".root";
400-
fDataSet.Snapshot("AnalysisTree", fOutName, finalList);
409+
if (!finalList.empty())
410+
fDataFrame.Snapshot("AnalysisTree", fOutName, finalList);
411+
else
412+
fDataFrame.Snapshot("AnalysisTree", fOutName);
401413

402414
RESTInfo << "Re-importing analysis tree." << RESTendl;
403-
fDataSet = ROOT::RDataFrame("AnalysisTree", fOutName);
415+
fDataFrame = ROOT::RDataFrame("AnalysisTree", fOutName);
404416

405417
TFile* f = TFile::Open(fOutName.c_str());
406418
fTree = (TChain*)f->Get("AnalysisTree");
407-
408-
RESTInfo << " - Dataset generated!" << RESTendl;
409419
}
410420

411421
///////////////////////////////////////////////
@@ -517,14 +527,32 @@ std::vector<std::string> TRestDataSet::FileSelection() {
517527
return fFileSelection;
518528
}
519529

530+
///////////////////////////////////////////////
531+
/// \brief This method returns a RDataFrame node with the number of
532+
/// samples inside the dataset by selecting a range. It will not
533+
/// modify internally the dataset. See ApplyRange to modify internally
534+
/// the dataset.
535+
///
536+
ROOT::RDF::RNode TRestDataSet::Range(size_t from, size_t to) { return fDataFrame.Range(from, to); }
537+
538+
///////////////////////////////////////////////
539+
/// \brief This method reduces the number of samples inside the
540+
/// dataset by selecting a range.
541+
///
542+
ROOT::RDF::RNode TRestDataSet::ApplyRange(size_t from, size_t to) {
543+
fDataFrame = fDataFrame.Range(from, to);
544+
RegenerateTree();
545+
return fDataFrame;
546+
}
547+
520548
///////////////////////////////////////////////
521549
/// \brief This function applies a TRestCut to the dataframe
522550
/// and returns a dataframe with the applied cuts. Note that
523551
/// the cuts are not applied directly to the dataframe on
524-
/// TRestDataSet, to do so you should do fDataSet = MakeCut(fCut);
552+
/// TRestDataSet, to do so you should do fDataFrame = MakeCut(fCut);
525553
///
526554
ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
527-
auto df = fDataSet;
555+
auto df = fDataFrame;
528556

529557
if (cut == nullptr) return df;
530558

@@ -561,6 +589,20 @@ ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
561589
return df;
562590
}
563591

592+
///////////////////////////////////////////////
593+
/// \brief It returns the number of entries found inside fDataFrame
594+
/// and prints out a warning if the number of entries inside the
595+
/// tree is not the same.
596+
///
597+
size_t TRestDataSet::GetEntries() {
598+
auto nEntries = fDataFrame.Count();
599+
if (*nEntries == (long long unsigned int)GetTree()->GetEntries()) return *nEntries;
600+
RESTWarning << "TRestDataSet::GetEntries. Number of tree entries is not the same as RDataFrame entries."
601+
<< RESTendl;
602+
RESTWarning << "Returning RDataFrame entries" << RESTendl;
603+
return *nEntries;
604+
}
605+
564606
///////////////////////////////////////////////
565607
/// \brief This function will add a new column to the RDataFrame using
566608
/// the same scheme as the usual RDF::Define method, but it will on top of
@@ -574,7 +616,7 @@ ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
574616
/// \endcode
575617
///
576618
ROOT::RDF::RNode TRestDataSet::DefineColumn(const std::string& columnName, const std::string& formula) {
577-
auto df = fDataSet;
619+
auto df = fDataFrame;
578620

579621
std::string evalFormula = formula;
580622
for (auto const& [name, properties] : fQuantity)
@@ -819,7 +861,7 @@ void TRestDataSet::InitFromConfigFile() {
819861
void TRestDataSet::Export(const std::string& filename, std::vector<std::string> excludeColumns) {
820862
RESTInfo << "Exporting dataset" << RESTendl;
821863

822-
std::vector<std::string> columns = fDataSet.GetColumnNames();
864+
std::vector<std::string> columns = fDataFrame.GetColumnNames();
823865
if (!excludeColumns.empty()) {
824866
columns.erase(std::remove_if(columns.begin(), columns.end(),
825867
[&excludeColumns](std::string elem) {
@@ -831,10 +873,10 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
831873
RESTInfo << "Re-Generating snapshot." << RESTendl;
832874
std::string user = getenv("USER");
833875
std::string fOutName = "/tmp/rest_output_" + user + ".root";
834-
fDataSet.Snapshot("AnalysisTree", fOutName, columns);
876+
fDataFrame.Snapshot("AnalysisTree", fOutName, columns);
835877

836878
RESTInfo << "Re-importing analysis tree." << RESTendl;
837-
fDataSet = ROOT::RDataFrame("AnalysisTree", fOutName);
879+
fDataFrame = ROOT::RDataFrame("AnalysisTree", fOutName);
838880

839881
TFile* f = TFile::Open(fOutName.c_str());
840882
fTree = (TChain*)f->Get("AnalysisTree");
@@ -846,7 +888,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
846888
RESTInfo << "Re-Generating snapshot." << RESTendl;
847889
std::string user = getenv("USER");
848890
std::string fOutName = "/tmp/rest_output_" + user + ".root";
849-
fDataSet.Snapshot("AnalysisTree", fOutName);
891+
fDataFrame.Snapshot("AnalysisTree", fOutName);
850892

851893
TFile* f = TFile::Open(fOutName.c_str());
852894
fTree = (TChain*)f->Get("AnalysisTree");
@@ -910,7 +952,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
910952
fprintf(f, "###\n");
911953
fprintf(f, "### Data starts here\n");
912954

913-
auto obsNames = fDataSet.GetColumnNames();
955+
auto obsNames = fDataFrame.GetColumnNames();
914956
std::string obsListStr = "";
915957
for (const auto& l : obsNames) {
916958
if (!obsListStr.empty()) obsListStr += ":";
@@ -938,7 +980,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
938980

939981
return;
940982
} else if (TRestTools::GetFileNameExtension(filename) == "root") {
941-
fDataSet.Snapshot("AnalysisTree", filename);
983+
fDataFrame.Snapshot("AnalysisTree", filename);
942984

943985
TFile* f = TFile::Open(filename.c_str(), "UPDATE");
944986
std::string name = this->GetName();
@@ -1038,7 +1080,7 @@ void TRestDataSet::Import(const std::string& fileName) {
10381080
else
10391081
ROOT::DisableImplicitMT();
10401082

1041-
fDataSet = ROOT::RDataFrame("AnalysisTree", fileName);
1083+
fDataFrame = ROOT::RDataFrame("AnalysisTree", fileName);
10421084

10431085
fTree = (TChain*)file->Get("AnalysisTree");
10441086
}
@@ -1104,7 +1146,7 @@ void TRestDataSet::Import(std::vector<std::string> fileNames) {
11041146
}
11051147

11061148
RESTInfo << "Opening list of files. First file: " << fileNames[0] << RESTendl;
1107-
fDataSet = ROOT::RDataFrame("AnalysisTree", fileNames);
1149+
fDataFrame = ROOT::RDataFrame("AnalysisTree", fileNames);
11081150

11091151
if (fTree != nullptr) {
11101152
delete fTree;

source/framework/sensitivity/inc/TRestComponentDataSet.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ class TRestComponentDataSet : public TRestComponent {
5454
/// The dataset used to initialize the distribution
5555
TRestDataSet fDataSet; //!
5656

57+
/// It helps to split large datasets when extracting the parameterization nodes
58+
long long unsigned int fSplitEntries = 600000000;
59+
60+
/// It creates a sample subset using a range definition
61+
TVector2 fDFRange = TVector2(0, 0);
62+
5763
/// It is true of the dataset was loaded without issues
5864
Bool_t fDataSetLoaded = false; //!
5965

@@ -84,6 +90,6 @@ class TRestComponentDataSet : public TRestComponent {
8490
TRestComponentDataSet(const char* cfgFileName, const std::string& name);
8591
~TRestComponentDataSet();
8692

87-
ClassDefOverride(TRestComponentDataSet, 3);
93+
ClassDefOverride(TRestComponentDataSet, 4);
8894
};
8995
#endif

source/framework/sensitivity/src/TRestComponentDataSet.cxx

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ void TRestComponentDataSet::PrintMetadata() {
148148
RESTMetadata << " " << RESTendl;
149149
}
150150

151+
if (fDFRange.X() != 0 || fDFRange.Y() != 0) {
152+
RESTMetadata << " DataFrame range: ( " << fDFRange.X() << ", " << fDFRange.Y() << ")" << RESTendl;
153+
RESTMetadata << " " << RESTendl;
154+
}
155+
151156
if (!fParameter.empty() && fParameterizationNodes.empty()) {
152157
RESTMetadata << "This component has no nodes!" << RESTendl;
153158
RESTMetadata << " Use: LoadDataSets() to initialize the nodes" << RESTendl;
@@ -383,15 +388,17 @@ std::vector<Double_t> TRestComponentDataSet::ExtractParameterizationNodes() {
383388
return vs;
384389
}
385390

386-
auto parValues = fDataSet.GetDataFrame().Take<double>(fParameter);
387-
for (const auto v : parValues) vs.push_back(v);
391+
auto GetUniqueElements = [](const std::vector<double>& vec) {
392+
std::set<double> uniqueSet(vec.begin(), vec.end());
393+
return std::vector<double>(uniqueSet.begin(), uniqueSet.end());
394+
};
388395

389-
std::vector<double>::iterator ip;
390-
ip = std::unique(vs.begin(), vs.begin() + vs.size());
391-
vs.resize(std::distance(vs.begin(), ip));
392-
std::sort(vs.begin(), vs.end());
393-
ip = std::unique(vs.begin(), vs.end());
394-
vs.resize(std::distance(vs.begin(), ip));
396+
for (size_t n = 0; n < 1 + fDataSet.GetEntries() / fSplitEntries; n++) {
397+
auto nEn = fDataSet.Range(n * fSplitEntries, (n + 1) * fSplitEntries).Count();
398+
auto parValues = fDataSet.Range(n * fSplitEntries, (n + 1) * fSplitEntries).Take<double>(fParameter);
399+
std::vector<double> uniqueVec = GetUniqueElements(*parValues);
400+
vs.insert(vs.end(), uniqueVec.begin(), uniqueVec.end());
401+
}
395402

396403
return vs;
397404
}
@@ -476,6 +483,9 @@ Bool_t TRestComponentDataSet::LoadDataSets() {
476483
fDataSet.Import(fullFileNames);
477484
fDataSetLoaded = true;
478485

486+
if (fDFRange.X() != 0 || fDFRange.Y() != 0)
487+
fDataSet.ApplyRange((size_t)fDFRange.X(), (size_t)fDFRange.Y());
488+
479489
if (fDataSet.GetTree() == nullptr) {
480490
RESTError << "Problem loading dataset from file list :" << RESTendl;
481491
for (const auto& f : fDataSetFileNames) RESTError << " - " << f << RESTendl;
@@ -486,6 +496,7 @@ Bool_t TRestComponentDataSet::LoadDataSets() {
486496

487497
if (VariablesOk() && WeightsOk()) {
488498
fParameterizationNodes = ExtractParameterizationNodes();
499+
RESTInfo << "Filling histograms" << RESTendl;
489500
FillHistograms();
490501
return fDataSetLoaded;
491502
}
@@ -515,11 +526,12 @@ Bool_t TRestComponentDataSet::WeightsOk() {
515526
Bool_t ok = true;
516527
std::vector cNames = fDataSet.GetDataFrame().GetColumnNames();
517528

518-
for (const auto& var : fWeights)
519-
if (std::count(cNames.begin(), cNames.end(), var) == 0) {
529+
for (const auto& var : fWeights) {
530+
if (!isANumber(var) && std::count(cNames.begin(), cNames.end(), var) == 0) {
520531
RESTError << "Weight ---> " << var << " <--- NOT found on dataset" << RESTendl;
521532
ok = false;
522533
}
534+
}
523535
return ok;
524536
}
525537

0 commit comments

Comments
 (0)