diff --git a/src/nb_func.cpp b/src/nb_func.cpp index d9fd1550..d2c6efab 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -299,7 +299,7 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept { // Check if the complex dispatch loop is needed bool complex_call = can_mutate_args || has_var_kwargs || has_var_args || - f->nargs >= NB_MAXARGS_SIMPLE; + f->nargs > NB_MAXARGS_SIMPLE; if (has_args) { for (size_t i = is_method; i < f->nargs; ++i) { @@ -690,16 +690,16 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self, entries using keyword arguments or default argument values provided in the bindings, if available. - 3. Ensure that either all keyword arguments were "consumed", or that + 2. Ensure that either all keyword arguments were "consumed", or that the function takes a kwargs argument to accept unconsumed kwargs. - 4. Any positional arguments still left get put into a tuple (for args), + 3. Any positional arguments still left get put into a tuple (for args), and any leftover kwargs get put into a dict. - 5. Pack everything into a vector; if we have nb::args or nb::kwargs, they are an - extra tuple or dict at the end of the positional arguments. + 4. Pack everything into a vector; if we have nb::args or nb::kwargs, + they become a tuple or dict at the end of the positional arguments. - 6. Call the function call dispatcher (func_data::impl) + 5. Call the function call dispatcher (func_data::impl) If one of these fail, move on to the next overload and keep trying until we get a result other than NB_NEXT_OVERLOAD. @@ -878,7 +878,8 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self, return result; } -/// Simplified nb_func_vectorcall variant for functions w/o keyword arguments +/// Simplified nb_func_vectorcall variant for functions w/o keyword arguments, +/// w/o default arguments, with no more than NB_MAXARGS_SIMPLE arguments, etc. static PyObject *nb_func_vectorcall_simple(PyObject *self, PyObject *const *args_in, size_t nargsf, diff --git a/tests/test_classes.cpp b/tests/test_classes.cpp index cf460a7f..be914f66 100644 --- a/tests/test_classes.cpp +++ b/tests/test_classes.cpp @@ -37,6 +37,9 @@ struct Struct { ~Struct() { destructed++; if (nb::is_alive()) struct_destructed.push_back(i); } int value() const { return i; } + int value_plus(int j, int k, int l, int m, int n, int o, int p) const { + return i + j + k + l + m + n + o + p; + } int getstate() const { ++pickled; return i; } void set_value(int value) { i = value; } void setstate(int value) { unpickled++; i = value; } @@ -163,6 +166,7 @@ NB_MODULE(test_classes_ext, m) { .def(nb::init<>()) .def(nb::init()) .def("value", &Struct::value) + .def("value_plus", &Struct::value_plus) .def("set_value", &Struct::set_value, "value"_a) .def("self", &Struct::self, nb::rv_policy::none) .def("none", [](Struct &) -> const Struct * { return nullptr; }) diff --git a/tests/test_classes.py b/tests/test_classes.py index 42cf12aa..437b01fc 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -59,6 +59,7 @@ def test02_static_overload(): def test03_instantiate(clean): s1: t.Struct = t.Struct() assert s1.value() == 5 + assert s1.value_plus(1, 2, 3, 4, 5, 6, 7) == 33 s2 = t.Struct(10) assert s2.value() == 10 del s1 diff --git a/tests/test_classes_ext.pyi.ref b/tests/test_classes_ext.pyi.ref index e06eb41b..a026b1ca 100644 --- a/tests/test_classes_ext.pyi.ref +++ b/tests/test_classes_ext.pyi.ref @@ -12,6 +12,8 @@ class Struct: def value(self) -> int: ... + def value_plus(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, /) -> int: ... + def set_value(self, value: int) -> None: ... def self(self) -> Struct: ... diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 8104a6d8..3c2010df 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -138,6 +138,12 @@ NB_MODULE(test_functions_ext, m) { return std::make_pair(args.size(), kwargs.size()); }, "a"_a, "b"_a, "myargs"_a, "mykwargs"_a); + /// Function with eight arguments + m.def("test_simple", + [](int i0, int i1, int i2, int i3, int i4, int i5, int i6, int i7) { + return i0 + i1 + i2 + i3 + i4 + i5 + i6 - i7; + }); + /// Test successful/unsuccessful tuple conversion, with rich output types m.def("test_tuple", []() -> nb::typed { return nb::make_tuple("Hello", 123); }); diff --git a/tests/test_functions.py b/tests/test_functions.py index eeae614f..d9da6ea6 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -22,6 +22,7 @@ def test01_capture(): assert t.test_02(5, 3) == 2 assert t.test_03(5, 3) == 44 assert t.test_04() == 60 + assert t.test_simple(0, 1, 2, 3, 4, 5, 6, 7) == 14 def test02_default_args(): diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index 64ba9af5..ecf60146 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -41,6 +41,8 @@ def test_07(arg0: int, arg1: int, /, *args, **kwargs) -> tuple[int, int]: ... @overload def test_07(a: int, b: int, *myargs, **mykwargs) -> tuple[int, int]: ... +def test_simple(arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, /) -> int: ... + @overload def test_tuple() -> tuple[str, int]: ...