@@ -113,6 +113,128 @@ def test_std_function(self):
113113 for x ,y in zip (rdf2 .Take ['ULong64_t' ]("rdfentry_" ), rdf2 .Take ['ULong64_t' ]("x" )):
114114 self .assertEqual (x * x , y )
115115
116-
116+ def test_cpp_free_function (self ):
117+ """
118+ Test that a C++ free function can be passed as a callable argument of a
119+ Define operation.
120+ """
121+
122+ test_cases = [
123+ # Free function with arguments
124+ {
125+ "name" : "input_ULong64_t" ,
126+ "decl" : "ULong64_t my_free_function(ULong64_t l) { return l; }" ,
127+ "coltype" : "ULong64_t" ,
128+ "define_args" : ["rdfentry_" ],
129+ "callable" : lambda : ROOT .my_free_function ,
130+ "extract_fn" : lambda x : x ,
131+ "expected_fn" : lambda i : i ,
132+ },
133+ # Free function with user defined struct
134+ {
135+ "name" : "input_user_defined_struct" ,
136+ "decl" : """
137+ struct MyStruct {
138+ ULong64_t value;
139+ };
140+ MyStruct my_free_function_struct(ULong64_t x) {
141+ MyStruct s; s.value = x; return s;
142+ }
143+ """ ,
144+ "coltype" : "MyStruct" ,
145+ "define_args" : ["rdfentry_" ],
146+ "callable" : lambda : ROOT .my_free_function_struct ,
147+ "extract_fn" : lambda s : s .value ,
148+ "expected_fn" : lambda i : i ,
149+ },
150+ # Free function with no arguments
151+ {
152+ "name" : "no_input" ,
153+ "decl" : "ULong64_t my_free_function_none() { return 42; }" ,
154+ "coltype" : "ULong64_t" ,
155+ "define_args" : [],
156+ "callable" : lambda : ROOT .my_free_function_none ,
157+ "extract_fn" : lambda x : x ,
158+ "expected_fn" : lambda _ : 42 ,
159+ },
160+ # Free function with more than one argument
161+ {
162+ "name" : "two_inputs" ,
163+ "decl" : """
164+ struct MyStruct2 {
165+ int value;
166+ };
167+ MyStruct2 my_free_function_two_args(MyStruct2 s, int x) {
168+ s.value = x; return s;
169+ }
170+ """ ,
171+ "coltype" : "MyStruct2" ,
172+ "define_args" : ["s_col" , "int_col" ],
173+ "setup_columns" : {
174+ "s_col" : "MyStruct2()" ,
175+ "int_col" : "(int)rdfentry_"
176+ },
177+ "callable" : lambda : ROOT .my_free_function_two_args ,
178+ "extract_fn" : lambda s : s .value ,
179+ "expected_fn" : lambda i : i ,
180+ }
181+ ]
182+
183+ for case in test_cases :
184+ with self .subTest (case = case ["name" ]):
185+ ROOT .gInterpreter .Declare (case ["decl" ])
186+ rdf = ROOT .RDataFrame (5 )
187+
188+ if "setup_columns" in case :
189+ for colname , gen_fn in case ["setup_columns" ].items ():
190+ rdf = rdf .Define (colname , gen_fn )
191+
192+ rdf = rdf .Define ("new_col" , case ["callable" ](), case .get ("define_args" , []))
193+
194+ outputs = rdf .Take [case ["coltype" ]]("new_col" )
195+ for i , out in enumerate (outputs ):
196+ expected = case ["expected_fn" ](i )
197+ actual = case ["extract_fn" ](out )
198+ self .assertEqual (actual , expected )
199+
200+ def test_cpp_free_function_overload (self ):
201+ """
202+ Test that an overload of a C++ free function can be passed as a callable argument of a
203+ Define operation with overloads.
204+ """
205+
206+ ROOT .gInterpreter .Declare ("""
207+ ULong64_t my_free_function_overload(ULong64_t l) { return l; }
208+ ULong64_t my_free_function_overload(ULong64_t l, ULong64_t m) { return l * m; }
209+ """ )
210+
211+ rdf = ROOT .RDataFrame (5 )
212+ rdf = rdf .Define ("new_col" , ROOT .my_free_function_overload , ["rdfentry_" ])
213+
214+ for x , y in zip (rdf .Take ["ULong64_t" ]("rdfentry_" ), rdf .Take ["ULong64_t" ]("new_col" )):
215+ self .assertEqual (x , y )
216+
217+ rdf = rdf .Define ("new_col_overload" , ROOT .my_free_function_overload , ["rdfentry_" , "rdfentry_" ])
218+ for x , y in zip (rdf .Take ["ULong64_t" ]("rdfentry_" ), rdf .Take ["ULong64_t" ]("new_col_overload" )):
219+ self .assertEqual (x * x , y )
220+
221+ def test_cpp_free_function_template (self ):
222+ """
223+ Test that a templated C++ free function can be passed as a callable argument of a
224+ Define operation.
225+ """
226+
227+ ROOT .gInterpreter .Declare ("""
228+ template <typename T>
229+ T my_free_function_template(T l) { return l; }
230+ """ )
231+
232+ rdf = ROOT .RDataFrame (5 )
233+ rdf = rdf .Define ("new_col" , ROOT .my_free_function_template ["ULong64_t" ], ["rdfentry_" ])
234+
235+ for x , y in zip (rdf .Take ["ULong64_t" ]("rdfentry_" ), rdf .Take ["ULong64_t" ]("new_col" )):
236+ self .assertEqual (x , y )
237+
238+
117239if __name__ == '__main__' :
118240 unittest .main ()
0 commit comments