@@ -113,6 +113,107 @@ def test_std_function(self):
113113 for x ,y in zip (rdf2 .Take ['ULong64_t' ]("rdfentry_" ), rdf2 .Take ['ULong64_t' ]("x" )):
114114 self .assertEqual (x * x , y )
115115
116-
116+ def test_cpp_free_function (self ):
117+ """
118+ Test that a C++ free function can be passed as a callable argument of a
119+ Define operation.
120+ """
121+
122+ test_cases = [
123+ # Free function with arguments
124+ {
125+ "name" : "input_ULong64_t" ,
126+ "decl" : """
127+ ULong64_t my_free_function(ULong64_t l) { return l; }
128+ """ ,
129+ "input" : True ,
130+ "coltype" : "ULong64_t" ,
131+ "callable" : lambda : ROOT .my_free_function ,
132+ "extract_fn" : lambda x : x ,
133+ },
134+ # Free function with user defined struct
135+ {
136+ "name" : "input_user_defined_struct" ,
137+ "decl" : """
138+ struct MyStruct {
139+ ULong64_t value;
140+ };
141+
142+ MyStruct my_free_function_struct(ULong64_t x) {
143+ MyStruct s;
144+ s.value = x;
145+ return s;
146+ }
147+ """ ,
148+ "input" : True ,
149+ "coltype" : "MyStruct" ,
150+ "callable" : lambda : ROOT .my_free_function_struct ,
151+ "extract_fn" : lambda s : s .value ,
152+ },
153+ # Free function with no arguments
154+ {
155+ "name" : "no_input" ,
156+ "decl" : """
157+ ULong64_t my_free_function_none() { return 0; }
158+ """ ,
159+ "input" : False ,
160+ "coltype" : "ULong64_t" ,
161+ "callable" : lambda : ROOT .my_free_function_none ,
162+ "extract_fn" : lambda x : x ,
163+ },
164+ ]
165+
166+ for case in test_cases :
167+ with self .subTest (case = case ["name" ]):
168+ ROOT .gInterpreter .Declare (case ["decl" ])
169+
170+ rdf = ROOT .RDataFrame (5 )
171+ rdf = rdf .Define ("new_col" , case ["callable" ](), ["rdfentry_" ]) if case ["input" ] else rdf .Define ("new_col" , case ["callable" ]())
172+
173+ inputs = rdf .Take ["ULong64_t" ]("rdfentry_" ) if case ["input" ] else [0 ] * 5
174+ outputs = rdf .Take [case ["coltype" ]]("new_col" )
175+
176+ for x , y in zip (inputs , outputs ):
177+ self .assertEqual (case ["extract_fn" ](y ), x )
178+
179+ def test_cpp_free_function_overloead (self ):
180+ """
181+ Test that an overload of a C++ free function can be passed as a callable argument of a
182+ Define operation with overloads.
183+ """
184+
185+ ROOT .gInterpreter .Declare ("""
186+ ULong64_t my_free_function_overload(ULong64_t l) { return l; }
187+ ULong64_t my_free_function_overload(ULong64_t l, ULong64_t m) { return l * m; }
188+ """ )
189+
190+ rdf = ROOT .RDataFrame (5 )
191+ rdf = rdf .Define ("new_col" , ROOT .my_free_function_overload , ["rdfentry_" ])
192+
193+ for x , y in zip (rdf .Take ["ULong64_t" ]("rdfentry_" ), rdf .Take ["ULong64_t" ]("new_col" )):
194+ self .assertEqual (x , y )
195+
196+ rdf = rdf .Define ("new_col_overload" , ROOT .my_free_function_overload , ["rdfentry_" , "rdfentry_" ])
197+ for x , y in zip (rdf .Take ["ULong64_t" ]("rdfentry_" ), rdf .Take ["ULong64_t" ]("new_col_overload" )):
198+ self .assertEqual (x * x , y )
199+
200+ def test_cpp_free_function_template (self ):
201+ """
202+ Test that a templated C++ free function can be passed as a callable argument of a
203+ Define operation.
204+ """
205+
206+ ROOT .gInterpreter .Declare ("""
207+ template <typename T>
208+ T my_free_function_template(T l) { return l; }
209+ """ )
210+
211+ rdf = ROOT .RDataFrame (5 )
212+ rdf = rdf .Define ("new_col" , ROOT .my_free_function_template ["ULong64_t" ], ["rdfentry_" ])
213+
214+ for x , y in zip (rdf .Take ["ULong64_t" ]("rdfentry_" ), rdf .Take ["ULong64_t" ]("new_col" )):
215+ self .assertEqual (x , y )
216+
217+
117218if __name__ == '__main__' :
118219 unittest .main ()
0 commit comments