Skip to content

Commit 028676a

Browse files
committed
[df] Avoid including ambiguous TTree branch names
When the input data source is a TTree, GetColumnNames gathers the list of all the available TTree branches. In case there are two branches in the tree (e.g. `el1` and `el2`), each of them has a sub-branch with the same name (e.g. `electron_pt`), TTree allows calling `GetBranch("electron_pt")` and returns the pointer to the sub-branch of the first main branch (i.e. `el1.electron_pt`). This behaviour can lead to ambiguities, thus avoid exposing the ambiguous column name via RDF. A test is added to exemplify this case.
1 parent 155fb39 commit 028676a

File tree

5 files changed

+144
-10
lines changed

5 files changed

+144
-10
lines changed

tree/dataframe/src/RTTreeDS.cxx

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ void InsertBranchName(std::set<std::string> &bNamesReg, std::vector<std::string>
102102
foundLeaves.insert(leaf);
103103
}
104104

105-
void ExploreBranch(TTree &t, std::set<std::string> &bNamesReg, std::vector<std::string> &bNames, TBranch *b,
106-
std::string prefix, std::string &friendName, bool allowDuplicates)
105+
void ExploreBranch(TTree &t, std::unordered_map<std::string, unsigned int> &duplicateTokens,
106+
std::set<std::string> &bNamesReg, std::vector<std::string> &bNames, TBranch *b, std::string prefix,
107+
std::string &friendName, bool allowDuplicates)
107108
{
108109
// We want to avoid situations of overlap between the prefix and the
109110
// sub-branch name that might happen when the branch is composite, e.g.
@@ -121,6 +122,15 @@ void ExploreBranch(TTree &t, std::set<std::string> &bNamesReg, std::vector<std::
121122
for (auto sb : *b->GetListOfBranches()) {
122123
TBranch *subBranch = static_cast<TBranch *>(sb);
123124
auto subBranchName = std::string(subBranch->GetName());
125+
126+
// Record names of sub branches, which could reapper in different branch hierarchies of the same dataset. For
127+
// example, the 'Muon' branch could have sub-branch 'pt', as well as the 'Electron' branch. Later we will
128+
// disambiguate by removing the top-level 'pt' branch which TTree doesn't warn about and would end up pointing to
129+
// the first sub-branch encountered in this exploration
130+
if (!duplicateTokens.insert({subBranchName, 1}).second) {
131+
duplicateTokens[subBranchName]++;
132+
}
133+
124134
auto fullName = prefix + subBranchName;
125135

126136
if (auto subNameFirstDot = subBranchName.find_first_of('.'); subNameFirstDot != std::string::npos) {
@@ -133,7 +143,7 @@ void ExploreBranch(TTree &t, std::set<std::string> &bNamesReg, std::vector<std::
133143
if (!prefix.empty())
134144
newPrefix = fullName + ".";
135145

136-
ExploreBranch(t, bNamesReg, bNames, subBranch, newPrefix, friendName, allowDuplicates);
146+
ExploreBranch(t, duplicateTokens, bNamesReg, bNames, subBranch, newPrefix, friendName, allowDuplicates);
137147

138148
auto branchDirectlyFromTree = t.GetBranch(fullName.c_str());
139149
if (!branchDirectlyFromTree)
@@ -147,7 +157,8 @@ void ExploreBranch(TTree &t, std::set<std::string> &bNamesReg, std::vector<std::
147157
}
148158
}
149159

150-
void GetBranchNamesImpl(TTree &t, std::set<std::string> &bNamesReg, std::vector<std::string> &bNames,
160+
void GetBranchNamesImpl(TTree &t, std::unordered_map<std::string, unsigned int> &duplicateTokens,
161+
std::set<std::string> &bNamesReg, std::vector<std::string> &bNames,
151162
std::set<TTree *> &analysedTrees, std::string &friendName, bool allowDuplicates)
152163
{
153164
std::set<TLeaf *> foundLeaves;
@@ -184,7 +195,7 @@ void GetBranchNamesImpl(TTree &t, std::set<std::string> &bNamesReg, std::vector<
184195
}
185196
} else if (branch->IsA() == TBranchObject::Class()) {
186197
// TBranchObject
187-
ExploreBranch(t, bNamesReg, bNames, branch, branchName + ".", friendName, allowDuplicates);
198+
ExploreBranch(t, duplicateTokens, bNamesReg, bNames, branch, branchName + ".", friendName, allowDuplicates);
188199
InsertBranchName(bNamesReg, bNames, branchName, friendName, allowDuplicates);
189200
} else {
190201
// TBranchElement
@@ -199,9 +210,10 @@ void GetBranchNamesImpl(TTree &t, std::set<std::string> &bNamesReg, std::vector<
199210
dotIsImplied = true;
200211

201212
if (dotIsImplied || branchName.back() == '.')
202-
ExploreBranch(t, bNamesReg, bNames, branch, "", friendName, allowDuplicates);
213+
ExploreBranch(t, duplicateTokens, bNamesReg, bNames, branch, "", friendName, allowDuplicates);
203214
else
204-
ExploreBranch(t, bNamesReg, bNames, branch, branchName + ".", friendName, allowDuplicates);
215+
ExploreBranch(t, duplicateTokens, bNamesReg, bNames, branch, branchName + ".", friendName,
216+
allowDuplicates);
205217

206218
InsertBranchName(bNamesReg, bNames, branchName, friendName, allowDuplicates);
207219
}
@@ -226,19 +238,32 @@ void GetBranchNamesImpl(TTree &t, std::set<std::string> &bNamesReg, std::vector<
226238
else
227239
frName = std::string(friendTree->GetName());
228240

229-
GetBranchNamesImpl(*friendTree, bNamesReg, bNames, analysedTrees, frName, allowDuplicates);
241+
GetBranchNamesImpl(*friendTree, duplicateTokens, bNamesReg, bNames, analysedTrees, frName, allowDuplicates);
230242
}
231243
}
232244

233245
///////////////////////////////////////////////////////////////////////////////
234246
/// Get all the branches names, including the ones of the friend trees
235247
std::vector<std::string> RetrieveDatasetSchema(TTree &t, bool allowDuplicates = true)
236248
{
249+
std::unordered_map<std::string, unsigned int> duplicateTokens;
250+
237251
std::set<std::string> bNamesSet;
238252
std::vector<std::string> bNames;
239253
std::set<TTree *> analysedTrees;
240254
std::string emptyFrName = "";
241-
GetBranchNamesImpl(t, bNamesSet, bNames, analysedTrees, emptyFrName, allowDuplicates);
255+
GetBranchNamesImpl(t, duplicateTokens, bNamesSet, bNames, analysedTrees, emptyFrName, allowDuplicates);
256+
257+
// Remove all sub-branches that have duplicate names between different branch hierarchies of the dataset
258+
bNames.erase(std::remove_if(bNames.begin(), bNames.end(),
259+
[&duplicateTokens](const auto &name) {
260+
if (auto it = duplicateTokens.find(name);
261+
it != duplicateTokens.end() && it->second > 1)
262+
return true;
263+
return false;
264+
}),
265+
bNames.end());
266+
242267
return bNames;
243268
}
244269
} // namespace

tree/dataframe/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ ROOT_ADD_GTEST(dataframe_cache dataframe_cache.cxx LIBRARIES ROOTDataFrame)
1313
ROOT_ADD_GTEST(dataframe_callbacks dataframe_callbacks.cxx LIBRARIES ROOTDataFrame)
1414
ROOT_ADD_GTEST(dataframe_cloning dataframe_cloning.cxx LIBRARIES ROOTDataFrame)
1515
ROOT_ADD_GTEST(dataframe_histomodels dataframe_histomodels.cxx LIBRARIES ROOTDataFrame)
16-
ROOT_ADD_GTEST(dataframe_interface dataframe_interface.cxx LIBRARIES ROOTDataFrame)
16+
ROOT_GENERATE_DICTIONARY(SimpleElectronDict ${CMAKE_CURRENT_SOURCE_DIR}/SimpleElectron.hxx LINKDEF ${CMAKE_CURRENT_SOURCE_DIR}/SimpleElectronLinkDef.hxx NO_CXXMODULE DEPENDENCIES RIO)
17+
ROOT_ADD_GTEST(dataframe_interface dataframe_interface.cxx SimpleElectronDict.cxx LIBRARIES ROOTDataFrame)
1718
ROOT_ADD_GTEST(dataframe_nodes dataframe_nodes.cxx LIBRARIES ROOTDataFrame)
1819
ROOT_ADD_GTEST(dataframe_regression dataframe_regression.cxx LIBRARIES Physics ROOTDataFrame GenVector)
1920
ROOT_ADD_GTEST(dataframe_utils dataframe_utils.cxx LIBRARIES ROOTDataFrame)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#ifndef ROOT_DATAFRAME_TEST_SIMPLE_ELECTRON
2+
#define ROOT_DATAFRAME_TEST_SIMPLE_ELECTRON
3+
4+
struct SimpleElectron {
5+
float electron_pt{};
6+
};
7+
8+
struct Wrapper {
9+
SimpleElectron electron{};
10+
float electron_pt{};
11+
};
12+
13+
#endif
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#ifdef __CLING__
2+
3+
#pragma link C++ class SimpleElectron + ;
4+
#pragma link C++ class Wrapper + ;
5+
6+
#endif

tree/dataframe/test/dataframe_interface.cxx

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
#include "gtest/gtest.h"
1515

16+
#include "SimpleElectron.hxx"
17+
1618
#include <thread>
1719

1820
using namespace ROOT;
@@ -994,3 +996,90 @@ TEST(RDataFrameInterface, GetNFilesFromMoreFiles)
994996
ROOT::RDataFrame df{"t", filenames};
995997
EXPECT_EQ(df.GetNFiles(), 3);
996998
}
999+
1000+
void expect_colnames_eq(const std::vector<std::string> &v1, const std::vector<std::string> &v2)
1001+
{
1002+
ASSERT_EQ(v1.size(), v2.size()) << "Vectors 'v1' and 'v2' are of unequal length";
1003+
for (std::size_t i = 0ull; i < v1.size(); ++i) {
1004+
EXPECT_EQ(v1[i], v2[i]) << "Vectors 'v1' and 'v2' differ at index " << i;
1005+
}
1006+
}
1007+
1008+
// https://github.com/root-project/root/issues/19392
1009+
TEST(RDataFrameInterface, GH19392)
1010+
{
1011+
class FileRAII {
1012+
private:
1013+
std::string fPath;
1014+
1015+
public:
1016+
explicit FileRAII(const std::string &path) : fPath(path) {}
1017+
~FileRAII() { std::remove(fPath.c_str()); }
1018+
auto GetPath() const { return fPath.c_str(); }
1019+
};
1020+
1021+
FileRAII fileraii{"dataframe_interface_gh19392.root"};
1022+
const auto treeName{"tree"};
1023+
1024+
{
1025+
auto file = std::make_unique<TFile>(fileraii.GetPath(), "RECREATE");
1026+
auto tree = std::make_unique<TTree>(treeName, treeName);
1027+
1028+
SimpleElectron el1;
1029+
el1.electron_pt = 10.f;
1030+
1031+
SimpleElectron el2;
1032+
el2.electron_pt = 20.f;
1033+
1034+
// The wrapper classes also have the same data member named 'electron_pt'
1035+
// just to make the parsing exercise more complicated
1036+
Wrapper wrap1;
1037+
wrap1.electron_pt = 30.f;
1038+
wrap1.electron.electron_pt = 40.f;
1039+
1040+
Wrapper wrap2;
1041+
wrap2.electron_pt = 50.f;
1042+
wrap2.electron.electron_pt = 60.f;
1043+
1044+
tree->Branch("el1", &el1);
1045+
tree->Branch("el2", &el2);
1046+
tree->Branch("wr1", &wrap1);
1047+
tree->Branch("wr2", &wrap2);
1048+
tree->Fill();
1049+
tree->Write();
1050+
}
1051+
ROOT::RDataFrame df(treeName, fileraii.GetPath());
1052+
const auto columns = df.GetColumnNames();
1053+
// The wrapper classes also have the same data member named 'electron_pt'
1054+
// just to make the parsing exercise more complicated
1055+
// It's thus expected to see e.g. 'wr1.electron_pt'
1056+
const std::vector<std::string> expectedCols{"el1",
1057+
"el1.electron_pt",
1058+
"el2",
1059+
"el2.electron_pt",
1060+
"wr1",
1061+
"wr1.electron",
1062+
"wr1.electron.electron_pt",
1063+
"wr1.electron_pt",
1064+
"wr2",
1065+
"wr2.electron",
1066+
"wr2.electron.electron_pt",
1067+
"wr2.electron_pt"};
1068+
expect_colnames_eq(columns, expectedCols);
1069+
1070+
// Check all values separately, ensures that all full leaf names point to
1071+
// the correct values
1072+
auto el1_pt = df.Take<float>("el1.electron_pt");
1073+
auto el2_pt = df.Take<float>("el2.electron_pt");
1074+
auto wr1_pt = df.Take<float>("wr1.electron_pt");
1075+
auto wr1_el_pt = df.Take<float>("wr1.electron.electron_pt");
1076+
auto wr2_pt = df.Take<float>("wr2.electron_pt");
1077+
auto wr2_el_pt = df.Take<float>("wr2.electron.electron_pt");
1078+
1079+
EXPECT_FLOAT_EQ(el1_pt->at(0), 10.f);
1080+
EXPECT_FLOAT_EQ(el2_pt->at(0), 20.f);
1081+
EXPECT_FLOAT_EQ(wr1_pt->at(0), 30.f);
1082+
EXPECT_FLOAT_EQ(wr1_el_pt->at(0), 40.f);
1083+
EXPECT_FLOAT_EQ(wr2_pt->at(0), 50.f);
1084+
EXPECT_FLOAT_EQ(wr2_el_pt->at(0), 60.f);
1085+
}

0 commit comments

Comments
 (0)