Skip to content

Commit 0e0f821

Browse files
committed
add tests
1 parent 1e40efd commit 0e0f821

File tree

1 file changed

+102
-1
lines changed

1 file changed

+102
-1
lines changed

bindings/pyroot/pythonizations/test/rdf_define_pyz.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,107 @@ def test_std_function(self):
113113
for x,y in zip(rdf2.Take['ULong64_t']("rdfentry_"), rdf2.Take['ULong64_t']("x")):
114114
self.assertEqual(x*x, y)
115115

116-
116+
def test_cpp_free_function(self):
117+
"""
118+
Test that a C++ free function can be passed as a callable argument of a
119+
Define operation.
120+
"""
121+
122+
test_cases = [
123+
# Free function with arguments
124+
{
125+
"name": "input_ULong64_t",
126+
"decl": """
127+
ULong64_t my_free_function(ULong64_t l) { return l; }
128+
""",
129+
"input": True,
130+
"coltype": "ULong64_t",
131+
"callable": lambda: ROOT.my_free_function,
132+
"extract_fn": lambda x: x,
133+
},
134+
# Free function with user defined struct
135+
{
136+
"name": "input_user_defined_struct",
137+
"decl": """
138+
struct MyStruct {
139+
ULong64_t value;
140+
};
141+
142+
MyStruct my_free_function_struct(ULong64_t x) {
143+
MyStruct s;
144+
s.value = x;
145+
return s;
146+
}
147+
""",
148+
"input": True,
149+
"coltype": "MyStruct",
150+
"callable": lambda: ROOT.my_free_function_struct,
151+
"extract_fn": lambda s: s.value,
152+
},
153+
# Free function with no arguments
154+
{
155+
"name": "no_input",
156+
"decl": """
157+
ULong64_t my_free_function_none() { return 0; }
158+
""",
159+
"input": False,
160+
"coltype": "ULong64_t",
161+
"callable": lambda: ROOT.my_free_function_none,
162+
"extract_fn": lambda x: x,
163+
},
164+
]
165+
166+
for case in test_cases:
167+
with self.subTest(case=case["name"]):
168+
ROOT.gInterpreter.Declare(case["decl"])
169+
170+
rdf = ROOT.RDataFrame(5)
171+
rdf = rdf.Define("new_col", case["callable"](), ["rdfentry_"]) if case["input"] else rdf.Define("new_col", case["callable"]())
172+
173+
inputs = rdf.Take["ULong64_t"]("rdfentry_") if case["input"] else [0] * 5
174+
outputs = rdf.Take[case["coltype"]]("new_col")
175+
176+
for x, y in zip(inputs, outputs):
177+
self.assertEqual(case["extract_fn"](y), x)
178+
179+
def test_cpp_free_function_overloead(self):
180+
"""
181+
Test that an overload of a C++ free function can be passed as a callable argument of a
182+
Define operation with overloads.
183+
"""
184+
185+
ROOT.gInterpreter.Declare("""
186+
ULong64_t my_free_function_overload(ULong64_t l) { return l; }
187+
ULong64_t my_free_function_overload(ULong64_t l, ULong64_t m) { return l * m; }
188+
""")
189+
190+
rdf = ROOT.RDataFrame(5)
191+
rdf = rdf.Define("new_col", ROOT.my_free_function_overload, ["rdfentry_"])
192+
193+
for x, y in zip(rdf.Take["ULong64_t"]("rdfentry_"), rdf.Take["ULong64_t"]("new_col")):
194+
self.assertEqual(x, y)
195+
196+
rdf = rdf.Define("new_col_overload", ROOT.my_free_function_overload, ["rdfentry_", "rdfentry_"])
197+
for x, y in zip(rdf.Take["ULong64_t"]("rdfentry_"), rdf.Take["ULong64_t"]("new_col_overload")):
198+
self.assertEqual(x * x, y)
199+
200+
def test_cpp_free_function_template(self):
201+
"""
202+
Test that a templated C++ free function can be passed as a callable argument of a
203+
Define operation.
204+
"""
205+
206+
ROOT.gInterpreter.Declare("""
207+
template <typename T>
208+
T my_free_function_template(T l) { return l; }
209+
""")
210+
211+
rdf = ROOT.RDataFrame(5)
212+
rdf = rdf.Define("new_col", ROOT.my_free_function_template["ULong64_t"], ["rdfentry_"])
213+
214+
for x, y in zip(rdf.Take["ULong64_t"]("rdfentry_"), rdf.Take["ULong64_t"]("new_col")):
215+
self.assertEqual(x, y)
216+
217+
117218
if __name__ == '__main__':
118219
unittest.main()

0 commit comments

Comments
 (0)