Skip to content

Commit c37c953

Browse files
committed
[Python][RDF] Ruff lint and format
1 parent 31906d5 commit c37c953

File tree

3 files changed

+114
-85
lines changed

3 files changed

+114
-85
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

Lines changed: 79 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
################################################################################
1010
import re
1111
import typing
12+
1213
from .._numbadeclare import _NumbaDeclareDecorator
1314

15+
1416
class FunctionJitter:
1517
"""
1618
This class JIT-compiles a Python callable with Numba, inferring the signature of the function from the types of the RDF columns.
@@ -38,11 +40,12 @@ def x_greater_than_2(x):
3840
fil1 = rdf.Filter( "Numba::" + func_call, "x is greater than 2")
3941
4042
"""
43+
4144
# Variable to store previous functions so as to not rejit.
4245
function_cache = {}
4346
lambda_function_counter = 0 # Counter to name the lambda functions
4447

45-
def __init__(self, rdf: 'RDataFrame') -> None:
48+
def __init__(self, rdf: "RDataFrame") -> None:
4649
self.rdf = rdf
4750
self.col_names: typing.List[str] = rdf.GetColumnNames()
4851
self.func: typing.Callable
@@ -61,28 +64,29 @@ def find_type(self, x):
6164
Else, if it is a numpy array, maps it to the corresponding RVec.
6265
Otherwise flags a type error
6366
Args:
64-
` x: Variable whose type is to be determined
67+
` x: Variable whose type is to be determined
6568
6669
Returns:
6770
type of the variable x
6871
"""
6972
try:
7073
import numpy as np
7174
except:
72-
raise ImportError(
73-
"Failed to import numpy during call to determine function signature.")
74-
from ._rdf_conversion_maps import FUNDAMENTAL_PYTHON_TYPES, TREE_TO_NUMBA, NUMPY_TO_TREE
75+
raise ImportError("Failed to import numpy during call to determine function signature.")
76+
from ._rdf_conversion_maps import FUNDAMENTAL_PYTHON_TYPES, NUMPY_TO_TREE, TREE_TO_NUMBA
77+
7578
if isinstance(x, str):
7679
# Can be string constant or can be column name
7780
if x in self.col_names: # If x is a column
7881
t = self.rdf.GetColumnType(x)
7982
if t in TREE_TO_NUMBA: # The column is a fundamental type from tree
8083
return TREE_TO_NUMBA[t]
81-
elif '<' in t: # The column type is a RVec<type>
82-
if '>>' in t: # It is a RVec<RVec<T>>
84+
elif "<" in t: # The column type is a RVec<type>
85+
if ">>" in t: # It is a RVec<RVec<T>>
8386
raise TypeError(
84-
f"Only columns with 'RVec<T>' where T is is a fundamental type are supported, not '{t}'.")
85-
g = re.match('(.*)<(.*)>', t).groups(0)
87+
f"Only columns with 'RVec<T>' where T is is a fundamental type are supported, not '{t}'."
88+
)
89+
g = re.match("(.*)<(.*)>", t).groups(0)
8690
if g[1] in TREE_TO_NUMBA:
8791
return "RVec<" + TREE_TO_NUMBA[g[1]] + ">"
8892
# Some data types leak into this branch, and it is not clear from where. Something needs to be implemented here so that this condition is never met.
@@ -91,20 +95,18 @@ def find_type(self, x):
9195
else:
9296
return t
9397
else:
94-
return 'str'
98+
return "str"
9599
#! Numba Declare does not support the "string" type (see _numbadeclare). Thus, string constants cannot be passed to Filter/Define.
96100
elif type(x) in FUNDAMENTAL_PYTHON_TYPES:
97101
return FUNDAMENTAL_PYTHON_TYPES[type(x)]
98102
elif isinstance(x, np.ndarray):
99103
if x.dtype.type in NUMPY_TO_TREE:
100104
return "RVec<" + NUMPY_TO_TREE[x.dtype.type] + ">"
101105
else:
102-
raise TypeError(
103-
f"Support for {x.dtype.type} arrays is not yet supported.")
106+
raise TypeError(f"Support for {x.dtype.type} arrays is not yet supported.")
104107
#! Need to work out how to map things like tuples, dicts, lists...
105108
else:
106-
raise TypeError(
107-
f"Type of {type(x).__name__}: {x} cannot be jitted.")
109+
raise TypeError(f"Type of {type(x).__name__}: {x} cannot be jitted.")
108110

109111
def find_function_params(self, func):
110112
"""
@@ -116,6 +118,7 @@ def find_function_params(self, func):
116118
117119
"""
118120
import inspect
121+
119122
func_sign = inspect.signature(func)
120123
# Find the Return type
121124
if func_sign.return_annotation is inspect.Signature.empty:
@@ -145,8 +148,11 @@ def generate_func_args(self, cols_list, extra_args):
145148
if n_cols > 0:
146149
# Check to see if all the parameters have been provided.
147150
if n_params != n_cols + n_constants:
148-
raise ValueError("Not Enough values provided in the column list and extra_args. The function required {} parameters only {} provided.".format(
149-
n_params, n_cols+n_constants))
151+
raise ValueError(
152+
"Not Enough values provided in the column list and extra_args. The function required {} parameters only {} provided.".format(
153+
n_params, n_cols + n_constants
154+
)
155+
)
150156

151157
# Mapping the column list to the first input parameters of the function.
152158
for idx, p in enumerate(self.params):
@@ -168,12 +174,15 @@ def find_function_signature(self):
168174
type_of_p = self.find_type(value_of_p)
169175
# Bool(s) in python are represented as True/False but in C++ are true/false. The following if statements are to account for that
170176
if type(value_of_p) == bool:
171-
if value_of_p: value_of_p = 'true'
172-
else: value_of_p = 'false'
177+
if value_of_p:
178+
value_of_p = "true"
179+
else:
180+
value_of_p = "false"
173181
else: # the parameter was not in func_args. Thus this parameter has to be mapped to a column of rdf
174182
if p not in self.col_names:
175183
raise Exception(
176-
f"Unable to map function argument {p} to a column.\nUse correct name of column or pass a list of column names.")
184+
f"Unable to map function argument {p} to a column.\nUse correct name of column or pass a list of column names."
185+
)
177186
value_of_p = p
178187
type_of_p = self.find_type(p)
179188
self.args_info[p] = (type_of_p, value_of_p)
@@ -184,7 +193,7 @@ def generate_function_call(self):
184193
Updates the class with new attributes func_call and func_sign to hold them,
185194
"""
186195
func = self.func
187-
if func.__name__ == '<lambda>':
196+
if func.__name__ == "<lambda>":
188197
func.__name__ = f"_lambda_func_number_{FunctionJitter.lambda_function_counter}"
189198
FunctionJitter.lambda_function_counter += 1
190199
self.func_call = f"{func.__name__}({', '.join(str(arg_info[1]) for arg_info in self.args_info.values())})"
@@ -214,15 +223,15 @@ def jit_function(self, func, cols_list, extra_args):
214223
func_call, func_sign = FunctionJitter.function_cache[func.__name__]
215224
self.get_function_params_args_call(func, cols_list, extra_args)
216225
if self.func_sign != func_sign:
217-
raise ValueError(
218-
"Trying to re-use a function. Do not change function signature.".format(func))
226+
raise ValueError("Trying to re-use a function. Do not change function signature.")
219227
return self.func_call
220228

221229
self.get_function_params_args_call(func, cols_list, extra_args)
222230
_NumbaDeclareDecorator(self.func_sign, self.return_type)(self.func)
223231
FunctionJitter.function_cache[self.func.__name__] = (self.func_call, self.func_sign)
224232
return self.func_call
225233

234+
226235
def _convert_to_vector(args):
227236
"""
228237
Converts a Python list of strings into an std::vector before passing such
@@ -243,22 +252,23 @@ def _convert_to_vector(args):
243252
return args
244253

245254
try:
246-
v = ROOT.std.vector['std::string'](args[0])
255+
v = ROOT.std.vector["std::string"](args[0])
247256
except TypeError:
248-
raise TypeError(
249-
f"The list of columns of a Filter operation can only contain strings. Please check: {args[0]}")
257+
raise TypeError(f"The list of columns of a Filter operation can only contain strings. Please check: {args[0]}")
250258

251259
return (v, *args[1:])
252260

261+
253262
def remove_fn_name_from_signature(signature):
    """
    Removes the function name from a signature string.
    The signature is expected to be in the form of "return_type function_name(type param1, type param2, ...)".

    The return type is everything before the space that separates it from the
    function name, so return types that contain spaces themselves
    (e.g. "unsigned int" or "std::map<int, int>") are kept whole.

    Args:
        signature (str): Full signature, including the function name.

    Returns:
        str: The signature without the function name, i.e. "return_type(type param1, ...)".

    Raises:
        ValueError: If the signature has no parentheses, or no space separating
            a return type from the function name before the opening parenthesis.
    """
    paren_pos = signature.find("(")
    if paren_pos == -1 or ")" not in signature:
        raise ValueError(f"Invalid signature format: {signature}")

    # Only the text before the parameter list can hold the return type and the name.
    head = signature[:paren_pos].rstrip()
    params = signature[paren_pos:]

    # Walk backwards over the function name to the space that precedes it.
    # Spaces nested in a template argument list of the name (depth > 0) are
    # not separators and are skipped.
    split_pos = -1
    depth = 0
    for pos in range(len(head) - 1, -1, -1):
        char = head[pos]
        if char == ">":
            depth += 1
        elif char == "<":
            depth -= 1
        elif char == " " and depth <= 0:
            split_pos = pos
            break

    # Names such as "operator>>" unbalance the bracket count; fall back to the
    # first space of the head, which is what the previous implementation used.
    if split_pos == -1:
        split_pos = head.find(" ")
    if split_pos == -1:
        raise ValueError(f"Invalid signature format: {signature}")

    return head[:split_pos].rstrip() + params
271+
262272

263273
def get_cpp_overload_from_templ_proxy(func, types=None):
264274
"""
@@ -269,9 +279,11 @@ def get_cpp_overload_from_templ_proxy(func, types=None):
269279
template_args = func.__template_args__[1:-1] if func.__template_args__ else ""
270280
return func.__overload__(signature, template_args)
271281

282+
272283
def get_column_types(rdf, cols):
    """
    Returns a list with the result of `rdf.GetColumnType` for every column
    name in `cols`, in the same order as `cols`.
    """
    column_types = []
    for column in cols:
        column_types.append(rdf.GetColumnType(column))
    return column_types
274285

286+
275287
def get_overload_based_on_args(overload_types, types):
276288
"""
277289
Gets the C++ overloads for a function based on the provided signatures and types.
@@ -287,7 +299,7 @@ def get_overload_based_on_args(overload_types, types):
287299
return next(iter(overload_types))
288300

289301
for full_sig, info in overload_types.items():
290-
input_types = info.get('input_types', ())
302+
input_types = info.get("input_types", ())
291303

292304
if len(input_types) != len(types):
293305
continue
@@ -302,7 +314,10 @@ def get_overload_based_on_args(overload_types, types):
302314
if match:
303315
return full_sig
304316

305-
raise ValueError(f"No matching overload found for types: {types} in signatures: {overload_types.keys()}. Please check the function overloads.")
317+
raise ValueError(
318+
f"No matching overload found for types: {types} in signatures: {overload_types.keys()}. Please check the function overloads."
319+
)
320+
306321

307322
def _get_cpp_signature(func, rdf, cols):
308323
"""
@@ -320,6 +335,7 @@ def _get_cpp_signature(func, rdf, cols):
320335
matched_overload = get_overload_based_on_args(overload_types, get_column_types(rdf, cols))
321336
return remove_fn_name_from_signature(matched_overload)
322337

338+
323339
def _to_std_function(func, rdf, cols):
324340
"""
325341
Converts a cppyy callable to std::function.
@@ -332,6 +348,7 @@ def _to_std_function(func, rdf, cols):
332348
signature = _get_cpp_signature(func, rdf, cols)
333349
return cppyy.gbl.std.function(signature)
334350

351+
335352
def _handle_cpp_callables(func, original_template, *args, rdf=None, cols=None):
336353
"""
337354
Checks whether the callable `func` is a cppyy proxy of one of these:
@@ -360,23 +377,24 @@ def _handle_cpp_callables(func, original_template, *args, rdf=None, cols=None):
360377

361378
import cppyy
362379

363-
is_cpp_functor = lambda : isinstance(getattr(func, '__call__', None), cppyy._backend.CPPOverload)
380+
is_cpp_functor = lambda: isinstance(getattr(func, "__call__", None), cppyy._backend.CPPOverload)
364381

365-
is_std_function = lambda : isinstance(getattr(func, 'target_type', None), cppyy._backend.CPPOverload)
382+
is_std_function = lambda: isinstance(getattr(func, "target_type", None), cppyy._backend.CPPOverload)
366383

367384
# handle free functions
368385
if callable(func) and not is_cpp_functor() and not is_std_function():
369386
try:
370387
func = _to_std_function(func, rdf, cols)
371388
except TypeError as e:
372389
if "Expected a cppyy callable" in str(e):
373-
pass # this function is not convertible to std::function, move on to the next check
390+
pass # this function is not convertible to std::function, move on to the next check
374391
else:
375392
raise
376393

377394
if is_cpp_functor() or is_std_function():
378395
return original_template[type(func)](*args)
379396

397+
380398
def _PyFilter(rdf, callable_or_str, *args, extra_args={}):
381399
"""
382400
Filters the entries of RDF according to a given condition.
@@ -418,42 +436,47 @@ def x_more_than_y(x):
418436
# The 1st argument is either a string or a python callable.
419437
if not callable(callable_or_str):
420438
raise TypeError(
421-
f"The first argument of a Filter operation should be a callable. {type(callable_or_str).__name__} object is not callable.")
439+
f"The first argument of a Filter operation should be a callable. {type(callable_or_str).__name__} object is not callable."
440+
)
422441

423442
if len(args) > 2:
424-
raise TypeError(
425-
f"Filter takes at most 3 positional arguments but {len(args) + 1} were given")
443+
raise TypeError(f"Filter takes at most 3 positional arguments but {len(args) + 1} were given")
426444

427445
func = callable_or_str
428446
col_list = []
429-
filter_name = ""
430-
447+
filter_name = ""
448+
431449
if len(args) == 1:
432-
if isinstance(args[0], list):
450+
if isinstance(args[0], list):
433451
col_list = args[0]
434452
elif isinstance(args[0], str):
435453
filter_name = args[0]
436454
else:
437455
raise ValueError(f"Argument should be either 'list' or 'str', not {type(args[0]).__name__}.")
438-
456+
439457
elif len(args) == 2:
440458
if isinstance(args[0], list) and isinstance(args[1], str):
441459
col_list = args[0]
442460
filter_name = args[1]
443461
else:
444-
raise ValueError(f"Arguments should be ('list', 'str',) not ({type(args[0]).__name__,type(args[1]).__name__}.")
462+
raise ValueError(
463+
f"Arguments should be ('list', 'str',) not ({type(args[0]).__name__, type(args[1]).__name__}."
464+
)
445465

446466
rdf_node = _handle_cpp_callables(func, rdf._OriginalFilter, func, *_convert_to_vector(args), rdf=rdf, cols=col_list)
447467
if rdf_node is not None:
448468
return rdf_node
449469

450470
jitter = FunctionJitter(rdf)
451-
func.__annotations__['return'] = 'bool' # return type for Filters is bool # Note: You can keep double and Filter still works.
452-
471+
func.__annotations__["return"] = (
472+
"bool" # return type for Filters is bool # Note: You can keep double and Filter still works.
473+
)
474+
453475
func_call = jitter.jit_function(func, col_list, extra_args)
454476
return rdf._OriginalFilter("Numba::" + func_call, filter_name)
455477

456-
def _PyDefine(rdf, col_name, callable_or_str, cols = [] , extra_args = {} ):
478+
479+
def _PyDefine(rdf, col_name, callable_or_str, cols=[], extra_args={}):
457480
"""
458481
Defines a new column in the RDataFrame.
459482
Arguments:
@@ -465,7 +488,7 @@ def _PyDefine(rdf, col_name, callable_or_str, cols = [] , extra_args = {} ):
465488
4. extra_args: non-columnar arguments to be passed to the callable.
466489
Returns:
467490
RDataFrame: rdf with new column defined
468-
491+
469492
Examples:
470493
1. rdf.Define("x", lambda: np.random.rand())
471494
Define using a python lambda
@@ -483,22 +506,26 @@ def x_scaled(x):
483506
484507
"""
485508
if not isinstance(col_name, str):
486-
raise TypeError(f"First argument of Define must be a valid string for the new column name. {type(col_name).__name__} is not a string.")
509+
raise TypeError(
510+
f"First argument of Define must be a valid string for the new column name. {type(col_name).__name__} is not a string."
511+
)
487512

488-
if isinstance(callable_or_str, str): # If string argument is passed. Invoke the Original Define.
513+
if isinstance(callable_or_str, str): # If string argument is passed. Invoke the Original Define.
489514
return rdf._OriginalDefine(col_name, callable_or_str)
490515

491-
if not callable(callable_or_str): # The 2st argument is either a string or a python callable.
492-
raise TypeError(f"The second argument of a Define operation should be a callable. {type(callable_or_str).__name__} object is not callable.")
493-
494-
if not isinstance(cols, list):
516+
if not callable(callable_or_str): # The 2st argument is either a string or a python callable.
517+
raise TypeError(
518+
f"The second argument of a Define operation should be a callable. {type(callable_or_str).__name__} object is not callable."
519+
)
520+
521+
if not isinstance(cols, list):
495522
raise TypeError(f"Define takes a column list as third arguments but {type(cols).__name__} was given.")
496-
523+
497524
func = callable_or_str
498525
rdf_node = _handle_cpp_callables(func, rdf._OriginalDefine, col_name, func, cols, rdf=rdf, cols=cols)
499526
if rdf_node is not None:
500527
return rdf_node
501528

502-
jitter = FunctionJitter(rdf)
529+
jitter = FunctionJitter(rdf)
503530
func_call = jitter.jit_function(func, cols, extra_args)
504531
return rdf._OriginalDefine(col_name, "Numba::" + func_call)

0 commit comments

Comments
 (0)