@@ -269,30 +269,37 @@ def get_column_types(rdf, cols):
269269 types = [rdf .GetColumnType (col ) for col in cols ]
270270 return types
271271
272- def get_overload_based_on_args (signatures , types ):
272+ def get_overload_based_on_args (overload_types , types ):
273273 """
274274 Gets the C++ overloads for a function based on the provided signatures and types.
275275 """
276- if not isinstance (signatures , list ) or not all ( isinstance ( sig , str ) for sig in signatures ):
277- raise TypeError ("Function signatures must be a list of strings ." )
276+ if not isinstance (overload_types , dict ):
277+ raise TypeError ("Overload types must be a dictionary ." )
278278
279279 if not isinstance (types , list ) or not all (isinstance (t , str ) for t in types ):
280280 raise TypeError ("Types must be a list of strings." )
281281
282- if len (signatures ) == 1 :
282+ if len (overload_types ) == 1 :
283283 # If there is only one signature, return it directly
284- return signatures [ 0 ]
284+ return next ( iter ( overload_types ))
285285
286- for sig in signatures :
287- args_str = sig [ sig . index ( '(' ) + 1 : sig . rindex ( ')' )]
286+ for full_sig , info in overload_types . items () :
287+ input_types = info . get ( 'input_types' , ())
288288
289- if not args_str and not types :
290- return sig
289+ if len ( input_types ) != len ( types ) :
290+ continue
291291
292- args = [arg .strip () for arg in args_str .split (',' ) if arg .strip ()]
292+ match = True
293+ for expected , actual in zip (input_types , types ):
294+ # Loose match: require actual to appear somewhere in expected (e.g. "int" in "const int&")
295+ if actual not in expected :
296+ match = False
297+ break
293298
294- if len (args ) == len (types ):
295- return sig
299+ if match :
300+ return full_sig
301+
302+ raise ValueError (f"No matching overload found for types: { types } in signatures: { overload_types .keys ()} . Please check the function overloads." )
296303
297304def _get_cpp_signature (func , rdf , cols ):
298305 """
@@ -306,12 +313,9 @@ def _get_cpp_signature(func, rdf, cols):
306313 if not isinstance (func , cppyy ._backend .CPPOverload ):
307314 raise TypeError (f"Expected a cppyy callable, got { type (func ).__name__ } " )
308315
309- get_overloads = func .__overload__
310- if not get_overloads :
311- raise ValueError (f"Function { func } has no `__overload__` method, cannot deduce signature." )
312-
313- signatures = [remove_fn_name_from_signature (sig ) for sig in get_overloads ()]
314- return get_overload_based_on_args (signatures , get_column_types (rdf , cols ))
316+ overload_types = func .func_overloads_types
317+ matched_overload = get_overload_based_on_args (overload_types , get_column_types (rdf , cols ))
318+ return remove_fn_name_from_signature (matched_overload )
315319
316320def _to_std_function (func , rdf , cols ):
317321 """
0 commit comments