Skip to content

Commit b7a30a5

Browse files
committed
updating to use the new cppyy patch of func_overloads_types
1 parent 0e0f821 commit b7a30a5

File tree

1 file changed

+22
-18
lines changed
  • bindings/pyroot/pythonizations/python/ROOT/_pythonization

1 file changed

+22
-18
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_rdf_pyz.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -269,30 +269,37 @@ def get_column_types(rdf, cols):
269269
types = [rdf.GetColumnType(col) for col in cols]
270270
return types
271271

272-
def get_overload_based_on_args(signatures, types):
272+
def get_overload_based_on_args(overload_types, types):
273273
"""
274274
Gets the C++ overloads for a function based on the provided signatures and types.
275275
"""
276-
if not isinstance(signatures, list) or not all(isinstance(sig, str) for sig in signatures):
277-
raise TypeError("Function signatures must be a list of strings.")
276+
if not isinstance(overload_types, dict):
277+
raise TypeError("Overload types must be a dictionary.")
278278

279279
if not isinstance(types, list) or not all(isinstance(t, str) for t in types):
280280
raise TypeError("Types must be a list of strings.")
281281

282-
if len(signatures) == 1:
282+
if len(overload_types) == 1:
283283
# If there is only one signature, return it directly
284-
return signatures[0]
284+
return next(iter(overload_types))
285285

286-
for sig in signatures:
287-
args_str = sig[sig.index('(') + 1:sig.rindex(')')]
286+
for full_sig, info in overload_types.items():
287+
input_types = info.get('input_types', ())
288288

289-
if not args_str and not types:
290-
return sig
289+
if len(input_types) != len(types):
290+
continue
291291

292-
args = [arg.strip() for arg in args_str.split(',') if arg.strip()]
292+
match = True
293+
for expected, actual in zip(input_types, types):
294+
# Loose match: require actual to appear somewhere in expected (e.g. "int" in "const int&")
295+
if actual not in expected:
296+
match = False
297+
break
293298

294-
if len(args) == len(types):
295-
return sig
299+
if match:
300+
return full_sig
301+
302+
raise ValueError(f"No matching overload found for types: {types} in signatures: {overload_types.keys()}. Please check the function overloads.")
296303

297304
def _get_cpp_signature(func, rdf, cols):
298305
"""
@@ -306,12 +313,9 @@ def _get_cpp_signature(func, rdf, cols):
306313
if not isinstance(func, cppyy._backend.CPPOverload):
307314
raise TypeError(f"Expected a cppyy callable, got {type(func).__name__}")
308315

309-
get_overloads = func.__overload__
310-
if not get_overloads:
311-
raise ValueError(f"Function {func} has no `__overload__` method, cannot deduce signature.")
312-
313-
signatures = [remove_fn_name_from_signature(sig) for sig in get_overloads()]
314-
return get_overload_based_on_args(signatures, get_column_types(rdf, cols))
316+
overload_types = func.func_overloads_types
317+
matched_overload = get_overload_based_on_args(overload_types, get_column_types(rdf, cols))
318+
return remove_fn_name_from_signature(matched_overload)
315319

316320
def _to_std_function(func, rdf, cols):
317321
"""

0 commit comments

Comments
 (0)