Skip to content

Commit 1b16f28

Browse files
committed
[Python] Extend the RDF define tests with cpp free functions
1 parent ae03573 commit 1b16f28

File tree

2 files changed

+177
-1
lines changed

2 files changed

+177
-1
lines changed

bindings/pyroot/pythonizations/test/rdf_define_pyz.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,128 @@ def test_std_function(self):
113113
for x,y in zip(rdf2.Take['ULong64_t']("rdfentry_"), rdf2.Take['ULong64_t']("x")):
114114
self.assertEqual(x*x, y)
115115

116-
116+
def test_cpp_free_function(self):
117+
"""
118+
Test that a C++ free function can be passed as a callable argument of a
119+
Define operation.
120+
"""
121+
122+
test_cases = [
123+
# Free function with arguments
124+
{
125+
"name": "input_ULong64_t",
126+
"decl": "ULong64_t my_free_function(ULong64_t l) { return l; }",
127+
"coltype": "ULong64_t",
128+
"define_args": ["rdfentry_"],
129+
"callable": lambda: ROOT.my_free_function,
130+
"extract_fn": lambda x: x,
131+
"expected_fn": lambda i: i,
132+
},
133+
# Free function with user defined struct
134+
{
135+
"name": "input_user_defined_struct",
136+
"decl": """
137+
struct MyStruct {
138+
ULong64_t value;
139+
};
140+
MyStruct my_free_function_struct(ULong64_t x) {
141+
MyStruct s; s.value = x; return s;
142+
}
143+
""",
144+
"coltype": "MyStruct",
145+
"define_args": ["rdfentry_"],
146+
"callable": lambda: ROOT.my_free_function_struct,
147+
"extract_fn": lambda s: s.value,
148+
"expected_fn": lambda i: i,
149+
},
150+
# Free function with no arguments
151+
{
152+
"name": "no_input",
153+
"decl": "ULong64_t my_free_function_none() { return 42; }",
154+
"coltype": "ULong64_t",
155+
"define_args": [],
156+
"callable": lambda: ROOT.my_free_function_none,
157+
"extract_fn": lambda x: x,
158+
"expected_fn": lambda _: 42,
159+
},
160+
# Free function with more than one argument
161+
{
162+
"name": "two_inputs",
163+
"decl": """
164+
struct MyStruct2 {
165+
int value;
166+
};
167+
MyStruct2 my_free_function_two_args(MyStruct2 s, int x) {
168+
s.value = x; return s;
169+
}
170+
""",
171+
"coltype": "MyStruct2",
172+
"define_args": ["s_col", "int_col"],
173+
"setup_columns": {
174+
"s_col": "MyStruct2()",
175+
"int_col": "(int)rdfentry_"
176+
},
177+
"callable": lambda: ROOT.my_free_function_two_args,
178+
"extract_fn": lambda s: s.value,
179+
"expected_fn": lambda i: i,
180+
}
181+
]
182+
183+
for case in test_cases:
184+
with self.subTest(case=case["name"]):
185+
ROOT.gInterpreter.Declare(case["decl"])
186+
rdf = ROOT.RDataFrame(5)
187+
188+
if "setup_columns" in case:
189+
for colname, gen_fn in case["setup_columns"].items():
190+
rdf = rdf.Define(colname, gen_fn)
191+
192+
rdf = rdf.Define("new_col", case["callable"](), case.get("define_args", []))
193+
194+
outputs = rdf.Take[case["coltype"]]("new_col")
195+
for i, out in enumerate(outputs):
196+
expected = case["expected_fn"](i)
197+
actual = case["extract_fn"](out)
198+
self.assertEqual(actual, expected)
199+
200+
def test_cpp_free_function_overload(self):
201+
"""
202+
Test that an overload of a C++ free function can be passed as a callable argument of a
203+
Define operation with overloads.
204+
"""
205+
206+
ROOT.gInterpreter.Declare("""
207+
ULong64_t my_free_function_overload(ULong64_t l) { return l; }
208+
ULong64_t my_free_function_overload(ULong64_t l, ULong64_t m) { return l * m; }
209+
""")
210+
211+
rdf = ROOT.RDataFrame(5)
212+
rdf = rdf.Define("new_col", ROOT.my_free_function_overload, ["rdfentry_"])
213+
214+
for x, y in zip(rdf.Take["ULong64_t"]("rdfentry_"), rdf.Take["ULong64_t"]("new_col")):
215+
self.assertEqual(x, y)
216+
217+
rdf = rdf.Define("new_col_overload", ROOT.my_free_function_overload, ["rdfentry_", "rdfentry_"])
218+
for x, y in zip(rdf.Take["ULong64_t"]("rdfentry_"), rdf.Take["ULong64_t"]("new_col_overload")):
219+
self.assertEqual(x * x, y)
220+
221+
def test_cpp_free_function_template(self):
222+
"""
223+
Test that a templated C++ free function can be passed as a callable argument of a
224+
Define operation.
225+
"""
226+
227+
ROOT.gInterpreter.Declare("""
228+
template <typename T>
229+
T my_free_function_template(T l) { return l; }
230+
""")
231+
232+
rdf = ROOT.RDataFrame(5)
233+
rdf = rdf.Define("new_col", ROOT.my_free_function_template["ULong64_t"], ["rdfentry_"])
234+
235+
for x, y in zip(rdf.Take["ULong64_t"]("rdfentry_"), rdf.Take["ULong64_t"]("new_col")):
236+
self.assertEqual(x, y)
237+
238+
117239
if __name__ == '__main__':
118240
unittest.main()

bindings/pyroot/pythonizations/test/rdf_filter_pyz.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,60 @@ def test_std_function(self):
139139

140140
self.assertEqual(c, 1)
141141

142+
def test_cpp_free_function(self):
143+
"""
144+
Test that a C++ free function can be passed as a callable argument of a
145+
Filter operation.
146+
"""
147+
148+
ROOT.gInterpreter.Declare(
149+
"""
150+
bool myfun(ULong64_t l) { return l == 0; }
151+
"""
152+
)
153+
154+
rdf = ROOT.RDataFrame(5)
155+
c = rdf.Filter(ROOT.myfun, ["rdfentry_"]).Count().GetValue()
156+
157+
self.assertEqual(c, 1)
158+
159+
def test_cpp_free_function_overload(self):
160+
"""
161+
Test that an overload of a C++ free function can be passed as a callable argument of a
162+
Filter operation with overloads.
163+
"""
164+
165+
ROOT.gInterpreter.Declare(
166+
"""
167+
bool myfun(ULong64_t l) { return l == 0; }
168+
bool myfun(int l) { return true; }
169+
"""
170+
)
171+
172+
rdf = ROOT.RDataFrame(5)
173+
c = rdf.Filter(ROOT.myfun, ["rdfentry_"]).Count().GetValue()
174+
175+
self.assertEqual(c, 1)
176+
177+
def test_cpp_free_function_template(self):
178+
"""
179+
Test that a C++ free function template can be passed as a callable argument of a
180+
Filter operation.
181+
"""
182+
183+
ROOT.gInterpreter.Declare(
184+
"""
185+
template <typename T>
186+
bool myfun_t(T l) { return l == 0; }
187+
"""
188+
)
189+
190+
rdf2 = ROOT.RDataFrame(5)
191+
c = rdf2.Define("x", "(int) rdfentry_") \
192+
.Filter(ROOT.myfun_t[int], ["x"]).Count().GetValue()
193+
194+
self.assertEqual(c, 1)
195+
142196

143197
if __name__ == "__main__":
144198
unittest.main()

0 commit comments

Comments
 (0)