99################################################################################
1010import re
1111import typing
12+
1213from .._numbadeclare import _NumbaDeclareDecorator
1314
15+
1416class FunctionJitter :
1517 """
1618 This class allows to jit a python callable with Numba, being able to infer the signature of the function from the types of the RDF columns.
@@ -38,11 +40,12 @@ def x_greater_than_2(x):
3840 fil1 = rdf.Filter( "Numba::" + func_call, "x is greater than 2")
3941
4042 """
43+
4144 # Variable to store previous functions so as to not rejit.
4245 function_cache = {}
4346 lambda_function_counter = 0 # Counter to name the lambda functions
4447
45- def __init__ (self , rdf : ' RDataFrame' ) -> None :
48+ def __init__ (self , rdf : " RDataFrame" ) -> None :
4649 self .rdf = rdf
4750 self .col_names : typing .List [str ] = rdf .GetColumnNames ()
4851 self .func : typing .Callable
@@ -61,28 +64,29 @@ def find_type(self, x):
6164 Else it is an a numpy array and maps it to corresponding RVec.
6265 Otherwise flags a type error
6366 Args:
64- ` x: Variable whose type is to be determined
67+ ` x: Variable whose type is to be determined
6568
6669 Returns:
6770 type of the variable x
6871 """
6972 try :
7073 import numpy as np
7174 except :
72- raise ImportError (
73- "Failed to import numpy during call to determine function signature." )
74- from . _rdf_conversion_maps import FUNDAMENTAL_PYTHON_TYPES , TREE_TO_NUMBA , NUMPY_TO_TREE
75+ raise ImportError ("Failed to import numpy during call to determine function signature." )
76+ from . _rdf_conversion_maps import FUNDAMENTAL_PYTHON_TYPES , NUMPY_TO_TREE , TREE_TO_NUMBA
77+
7578 if isinstance (x , str ):
7679 # Can be string constant or can be column name
7780 if x in self .col_names : # If x is a column
7881 t = self .rdf .GetColumnType (x )
7982 if t in TREE_TO_NUMBA : # The column is a fundamental type from tree
8083 return TREE_TO_NUMBA [t ]
81- elif '<' in t : # The column type is a RVec<type>
82- if '>>' in t : # It is a RVec<RVec<T>>
84+ elif "<" in t : # The column type is a RVec<type>
85+ if ">>" in t : # It is a RVec<RVec<T>>
8386 raise TypeError (
84- f"Only columns with 'RVec<T>' where T is is a fundamental type are supported, not '{ t } '." )
85- g = re .match ('(.*)<(.*)>' , t ).groups (0 )
87+ f"Only columns with 'RVec<T>' where T is is a fundamental type are supported, not '{ t } '."
88+ )
89+ g = re .match ("(.*)<(.*)>" , t ).groups (0 )
8690 if g [1 ] in TREE_TO_NUMBA :
8791 return "RVec<" + TREE_TO_NUMBA [g [1 ]] + ">"
8892 # There are data type that leak into here. Not sure from where. But need to implement something here such that this condition is never met.
@@ -91,20 +95,18 @@ def find_type(self, x):
9195 else :
9296 return t
9397 else :
94- return ' str'
98+ return " str"
9599 #! Numba Declare does not support "string" type. Check _numbadeclare.Thus, Cannot pass string constants to the filter/Defines..
96100 elif type (x ) in FUNDAMENTAL_PYTHON_TYPES :
97101 return FUNDAMENTAL_PYTHON_TYPES [type (x )]
98102 elif isinstance (x , np .ndarray ):
99103 if x .dtype .type in NUMPY_TO_TREE :
100104 return "RVec<" + NUMPY_TO_TREE [x .dtype .type ] + ">"
101105 else :
102- raise TypeError (
103- f"Support for { x .dtype .type } arrays is not yet supported." )
106+ raise TypeError (f"Support for { x .dtype .type } arrays is not yet supported." )
104107 #! Need to work out how to map things like tuples, dicts, lists...
105108 else :
106- raise TypeError (
107- f"Type of { type (x ).__name__ } : { x } cannot be jitted." )
109+ raise TypeError (f"Type of { type (x ).__name__ } : { x } cannot be jitted." )
108110
109111 def find_function_params (self , func ):
110112 """
@@ -116,6 +118,7 @@ def find_function_params(self, func):
116118
117119 """
118120 import inspect
121+
119122 func_sign = inspect .signature (func )
120123 # Find the Return type
121124 if func_sign .return_annotation is inspect .Signature .empty :
@@ -145,8 +148,11 @@ def generate_func_args(self, cols_list, extra_args):
145148 if n_cols > 0 :
146149 # Check to see if all the parameters have been provided.
147150 if n_params != n_cols + n_constants :
148- raise ValueError ("Not Enough values provided in the column list and extra_args. The function required {} parameters only {} provided." .format (
149- n_params , n_cols + n_constants ))
151+ raise ValueError (
152+ "Not Enough values provided in the column list and extra_args. The function required {} parameters only {} provided." .format (
153+ n_params , n_cols + n_constants
154+ )
155+ )
150156
151157 # Mapping the column list to the first input parameters of the function.
152158 for idx , p in enumerate (self .params ):
@@ -168,12 +174,15 @@ def find_function_signature(self):
168174 type_of_p = self .find_type (value_of_p )
169175 # Bool(s) in python are represented as True/False but in C++ are true/false. The following if statements are to account for that
170176 if type (value_of_p ) == bool :
171- if value_of_p : value_of_p = 'true'
172- else : value_of_p = 'false'
177+ if value_of_p :
178+ value_of_p = "true"
179+ else :
180+ value_of_p = "false"
173181 else : # the parameter was not in func_args. Thus this parameter has to be mapped to a column of rdf
174182 if p not in self .col_names :
175183 raise Exception (
176- f"Unable to map function argument { p } to a column.\n Use correct name of column or pass a list of column names." )
184+ f"Unable to map function argument { p } to a column.\n Use correct name of column or pass a list of column names."
185+ )
177186 value_of_p = p
178187 type_of_p = self .find_type (p )
179188 self .args_info [p ] = (type_of_p , value_of_p )
@@ -184,7 +193,7 @@ def generate_function_call(self):
184193 Updates the class with new attributes func_call and func_sign to hold them,
185194 """
186195 func = self .func
187- if func .__name__ == ' <lambda>' :
196+ if func .__name__ == " <lambda>" :
188197 func .__name__ = f"_lambda_func_number_{ FunctionJitter .lambda_function_counter } "
189198 FunctionJitter .lambda_function_counter += 1
190199 self .func_call = f"{ func .__name__ } ({ ', ' .join (str (arg_info [1 ]) for arg_info in self .args_info .values ())} )"
@@ -214,15 +223,15 @@ def jit_function(self, func, cols_list, extra_args):
214223 func_call , func_sign = FunctionJitter .function_cache [func .__name__ ]
215224 self .get_function_params_args_call (func , cols_list , extra_args )
216225 if self .func_sign != func_sign :
217- raise ValueError (
218- "Trying to re-use a function. Do not change function signature." .format (func ))
226+ raise ValueError ("Trying to re-use a function. Do not change function signature." )
219227 return self .func_call
220228
221229 self .get_function_params_args_call (func , cols_list , extra_args )
222230 _NumbaDeclareDecorator (self .func_sign , self .return_type )(self .func )
223231 FunctionJitter .function_cache [self .func .__name__ ] = (self .func_call , self .func_sign )
224232 return self .func_call
225233
234+
226235def _convert_to_vector (args ):
227236 """
228237 Converts a Python list of strings into an std::vector before passing such
@@ -243,22 +252,23 @@ def _convert_to_vector(args):
243252 return args
244253
245254 try :
246- v = ROOT .std .vector [' std::string' ](args [0 ])
255+ v = ROOT .std .vector [" std::string" ](args [0 ])
247256 except TypeError :
248- raise TypeError (
249- f"The list of columns of a Filter operation can only contain strings. Please check: { args [0 ]} " )
257+ raise TypeError (f"The list of columns of a Filter operation can only contain strings. Please check: { args [0 ]} " )
250258
251259 return (v , * args [1 :])
252260
261+
253262def remove_fn_name_from_signature (signature ):
254263 """
255264 Removes the function name from a signature string.
256265 The signature is expected to be in the form of "return_type function_name(type param1, type param2, ...)".
257266 """
258- if '(' not in signature or ')' not in signature :
267+ if "(" not in signature or ")" not in signature :
259268 raise ValueError (f"Invalid signature format: { signature } " )
260269
261- return signature [:signature .index (' ' )] + signature [signature .index ('(' ):]
270+ return signature [: signature .index (" " )] + signature [signature .index ("(" ) :]
271+
262272
263273def get_cpp_overload_from_templ_proxy (func , types = None ):
264274 """
@@ -269,9 +279,11 @@ def get_cpp_overload_from_templ_proxy(func, types=None):
269279 template_args = func .__template_args__ [1 :- 1 ] if func .__template_args__ else ""
270280 return func .__overload__ (signature , template_args )
271281
282+
272283def get_column_types (rdf , cols ):
273284 return [rdf .GetColumnType (col ) for col in cols ]
274285
286+
275287def get_overload_based_on_args (overload_types , types ):
276288 """
277289 Gets the C++ overloads for a function based on the provided signatures and types.
@@ -287,7 +299,7 @@ def get_overload_based_on_args(overload_types, types):
287299 return next (iter (overload_types ))
288300
289301 for full_sig , info in overload_types .items ():
290- input_types = info .get (' input_types' , ())
302+ input_types = info .get (" input_types" , ())
291303
292304 if len (input_types ) != len (types ):
293305 continue
@@ -302,7 +314,10 @@ def get_overload_based_on_args(overload_types, types):
302314 if match :
303315 return full_sig
304316
305- raise ValueError (f"No matching overload found for types: { types } in signatures: { overload_types .keys ()} . Please check the function overloads." )
317+ raise ValueError (
318+ f"No matching overload found for types: { types } in signatures: { overload_types .keys ()} . Please check the function overloads."
319+ )
320+
306321
307322def _get_cpp_signature (func , rdf , cols ):
308323 """
@@ -320,6 +335,7 @@ def _get_cpp_signature(func, rdf, cols):
320335 matched_overload = get_overload_based_on_args (overload_types , get_column_types (rdf , cols ))
321336 return remove_fn_name_from_signature (matched_overload )
322337
338+
323339def _to_std_function (func , rdf , cols ):
324340 """
325341 Converts a cppyy callable to std::function.
@@ -332,6 +348,7 @@ def _to_std_function(func, rdf, cols):
332348 signature = _get_cpp_signature (func , rdf , cols )
333349 return cppyy .gbl .std .function (signature )
334350
351+
335352def _handle_cpp_callables (func , original_template , * args , rdf = None , cols = None ):
336353 """
337354 Checks whether the callable `func` is a cppyy proxy of one of these:
@@ -360,23 +377,24 @@ def _handle_cpp_callables(func, original_template, *args, rdf=None, cols=None):
360377
361378 import cppyy
362379
363- is_cpp_functor = lambda : isinstance (getattr (func , ' __call__' , None ), cppyy ._backend .CPPOverload )
380+ is_cpp_functor = lambda : isinstance (getattr (func , " __call__" , None ), cppyy ._backend .CPPOverload )
364381
365- is_std_function = lambda : isinstance (getattr (func , ' target_type' , None ), cppyy ._backend .CPPOverload )
382+ is_std_function = lambda : isinstance (getattr (func , " target_type" , None ), cppyy ._backend .CPPOverload )
366383
367384 # handle free functions
368385 if callable (func ) and not is_cpp_functor () and not is_std_function ():
369386 try :
370387 func = _to_std_function (func , rdf , cols )
371388 except TypeError as e :
372389 if "Expected a cppyy callable" in str (e ):
373- pass # this function is not convertible to std::function, move on to the next check
390+ pass # this function is not convertible to std::function, move on to the next check
374391 else :
375392 raise
376393
377394 if is_cpp_functor () or is_std_function ():
378395 return original_template [type (func )](* args )
379396
397+
380398def _PyFilter (rdf , callable_or_str , * args , extra_args = {}):
381399 """
382400 Filters the entries of RDF according to a given condition.
@@ -418,42 +436,47 @@ def x_more_than_y(x):
418436 # The 1st argument is either a string or a python callable.
419437 if not callable (callable_or_str ):
420438 raise TypeError (
421- f"The first argument of a Filter operation should be a callable. { type (callable_or_str ).__name__ } object is not callable." )
439+ f"The first argument of a Filter operation should be a callable. { type (callable_or_str ).__name__ } object is not callable."
440+ )
422441
423442 if len (args ) > 2 :
424- raise TypeError (
425- f"Filter takes at most 3 positional arguments but { len (args ) + 1 } were given" )
443+ raise TypeError (f"Filter takes at most 3 positional arguments but { len (args ) + 1 } were given" )
426444
427445 func = callable_or_str
428446 col_list = []
429- filter_name = ""
430-
447+ filter_name = ""
448+
431449 if len (args ) == 1 :
432- if isinstance (args [0 ], list ):
450+ if isinstance (args [0 ], list ):
433451 col_list = args [0 ]
434452 elif isinstance (args [0 ], str ):
435453 filter_name = args [0 ]
436454 else :
437455 raise ValueError (f"Argument should be either 'list' or 'str', not { type (args [0 ]).__name__ } ." )
438-
456+
439457 elif len (args ) == 2 :
440458 if isinstance (args [0 ], list ) and isinstance (args [1 ], str ):
441459 col_list = args [0 ]
442460 filter_name = args [1 ]
443461 else :
444- raise ValueError (f"Arguments should be ('list', 'str',) not ({ type (args [0 ]).__name__ ,type (args [1 ]).__name__ } ." )
462+ raise ValueError (
463+ f"Arguments should be ('list', 'str',) not ({ type (args [0 ]).__name__ , type (args [1 ]).__name__ } ."
464+ )
445465
446466 rdf_node = _handle_cpp_callables (func , rdf ._OriginalFilter , func , * _convert_to_vector (args ), rdf = rdf , cols = col_list )
447467 if rdf_node is not None :
448468 return rdf_node
449469
450470 jitter = FunctionJitter (rdf )
451- func .__annotations__ ['return' ] = 'bool' # return type for Filters is bool # Note: You can keep double and Filter still works.
452-
471+ func .__annotations__ ["return" ] = (
472+ "bool" # return type for Filters is bool # Note: You can keep double and Filter still works.
473+ )
474+
453475 func_call = jitter .jit_function (func , col_list , extra_args )
454476 return rdf ._OriginalFilter ("Numba::" + func_call , filter_name )
455477
456- def _PyDefine (rdf , col_name , callable_or_str , cols = [] , extra_args = {} ):
478+
479+ def _PyDefine (rdf , col_name , callable_or_str , cols = [], extra_args = {}):
457480 """
458481 Defines a new column in the RDataFrame.
459482 Arguments:
@@ -465,7 +488,7 @@ def _PyDefine(rdf, col_name, callable_or_str, cols = [] , extra_args = {} ):
465488 4. extra_args: non-columnar arguments to be passed to the callable.
466489 Returns:
467490 RDataFrame: rdf with new column defined
468-
491+
469492 Examples:
470493 1. rdf.Define("x", lambda: np.random.rand())
471494 Define using a python lambda
@@ -483,22 +506,26 @@ def x_scaled(x):
483506
484507 """
485508 if not isinstance (col_name , str ):
486- raise TypeError (f"First argument of Define must be a valid string for the new column name. { type (col_name ).__name__ } is not a string." )
509+ raise TypeError (
510+ f"First argument of Define must be a valid string for the new column name. { type (col_name ).__name__ } is not a string."
511+ )
487512
488- if isinstance (callable_or_str , str ): # If string argument is passed. Invoke the Original Define.
513+ if isinstance (callable_or_str , str ): # If string argument is passed. Invoke the Original Define.
489514 return rdf ._OriginalDefine (col_name , callable_or_str )
490515
491- if not callable (callable_or_str ): # The 2st argument is either a string or a python callable.
492- raise TypeError (f"The second argument of a Define operation should be a callable. { type (callable_or_str ).__name__ } object is not callable." )
493-
494- if not isinstance (cols , list ):
516+ if not callable (callable_or_str ): # The 2st argument is either a string or a python callable.
517+ raise TypeError (
518+ f"The second argument of a Define operation should be a callable. { type (callable_or_str ).__name__ } object is not callable."
519+ )
520+
521+ if not isinstance (cols , list ):
495522 raise TypeError (f"Define takes a column list as third arguments but { type (cols ).__name__ } was given." )
496-
523+
497524 func = callable_or_str
498525 rdf_node = _handle_cpp_callables (func , rdf ._OriginalDefine , col_name , func , cols , rdf = rdf , cols = cols )
499526 if rdf_node is not None :
500527 return rdf_node
501528
502- jitter = FunctionJitter (rdf )
529+ jitter = FunctionJitter (rdf )
503530 func_call = jitter .jit_function (func , cols , extra_args )
504531 return rdf ._OriginalDefine (col_name , "Numba::" + func_call )
0 commit comments