Skip to content

Commit 1e40efd

Browse files
committed
add support for method overloads
1 parent 30b8c98 commit 1e40efd

File tree

1 file changed

+37
-10
lines changed
  • bindings/pyroot/pythonizations/python/ROOT/_pythonization

1 file changed

+37
-10
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,31 @@ def get_column_types(rdf, cols):
269269
types = [rdf.GetColumnType(col) for col in cols]
270270
return types
271271

272+
def get_overload_based_on_args(signatures, types):
273+
"""
274+
Gets the C++ overloads for a function based on the provided signatures and types.
275+
"""
276+
if not isinstance(signatures, list) or not all(isinstance(sig, str) for sig in signatures):
277+
raise TypeError("Function signatures must be a list of strings.")
278+
279+
if not isinstance(types, list) or not all(isinstance(t, str) for t in types):
280+
raise TypeError("Types must be a list of strings.")
281+
282+
if len(signatures) == 1:
283+
# If there is only one signature, return it directly
284+
return signatures[0]
285+
286+
for sig in signatures:
287+
args_str = sig[sig.index('(') + 1:sig.rindex(')')]
288+
289+
if not args_str and not types:
290+
return sig
291+
292+
args = [arg.strip() for arg in args_str.split(',') if arg.strip()]
293+
294+
if len(args) == len(types):
295+
return sig
296+
272297
def _get_cpp_signature(func, rdf, cols):
273298
"""
274299
Gets the C++ signature of a cppyy callable.
@@ -286,12 +311,7 @@ def _get_cpp_signature(func, rdf, cols):
286311
raise ValueError(f"Function {func} has no `__overload__` method, cannot deduce signature.")
287312

288313
signatures = [remove_fn_name_from_signature(sig) for sig in get_overloads()]
289-
if len(signatures) != 1:
290-
raise NotImplementedError(
291-
f"Function {func} has multiple overloads :)"
292-
)
293-
294-
return signatures[0]
314+
return get_overload_based_on_args(signatures, get_column_types(rdf, cols))
295315

296316
def _to_std_function(func, rdf, cols):
297317
"""
@@ -310,6 +330,7 @@ def _handle_cpp_callables(func, original_template, *args, rdf=None, cols=None):
310330
Checks whether the callable `func` is a cppyy proxy of one of these:
311331
1. C++ functor
312332
2. std::function
333+
3. C++ free function
313334
314335
The cases above are supported by cppyy, so we can just invoke the original
315336
cppyy TemplateProxy (Filter or Define) with the callable as argument.
@@ -332,12 +353,18 @@ def _handle_cpp_callables(func, original_template, *args, rdf=None, cols=None):
332353

333354
is_cpp_functor = lambda : isinstance(getattr(func, '__call__', None), cppyy._backend.CPPOverload)
334355

335-
# handle free functions
336-
if callable(func) and not is_cpp_functor():
337-
func = _to_std_function(func, rdf, cols)
338-
339356
is_std_function = lambda : isinstance(getattr(func, 'target_type', None), cppyy._backend.CPPOverload)
340357

358+
# handle free functions
359+
if callable(func) and not is_cpp_functor() and not is_std_function():
360+
try:
361+
func = _to_std_function(func, rdf, cols)
362+
except TypeError as e:
363+
if "Expected a cppyy callable" in str(e):
364+
pass # this function is not convertible to std::function, move on to the next check
365+
else:
366+
raise
367+
341368
if is_cpp_functor() or is_std_function():
342369
return original_template[type(func)](*args)
343370

0 commit comments

Comments
 (0)