Skip to content

Commit e507b11

Browse files
authored
Allow simple function dispatch if nargs == NB_MAXARGS_SIMPLE (#1193)
1 parent 9c414f5 commit e507b11

File tree

7 files changed

+24
-7
lines changed

7 files changed

+24
-7
lines changed

src/nb_func.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ PyObject *nb_func_new(const func_data_prelim_base *f) noexcept {
299299

300300
// Check if the complex dispatch loop is needed
301301
bool complex_call = can_mutate_args || has_var_kwargs || has_var_args ||
302-
f->nargs >= NB_MAXARGS_SIMPLE;
302+
f->nargs > NB_MAXARGS_SIMPLE;
303303

304304
if (has_args) {
305305
for (size_t i = is_method; i < f->nargs; ++i) {
@@ -690,16 +690,16 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
690690
entries using keyword arguments or default argument values provided
691691
in the bindings, if available.
692692
693-
3. Ensure that either all keyword arguments were "consumed", or that
693+
2. Ensure that either all keyword arguments were "consumed", or that
694694
the function takes a kwargs argument to accept unconsumed kwargs.
695695
696-
4. Any positional arguments still left get put into a tuple (for args),
696+
3. Any positional arguments still left get put into a tuple (for args),
697697
and any leftover kwargs get put into a dict.
698698
699-
5. Pack everything into a vector; if we have nb::args or nb::kwargs, they are an
700-
extra tuple or dict at the end of the positional arguments.
699+
4. Pack everything into a vector; if we have nb::args or nb::kwargs,
700+
they become a tuple or dict at the end of the positional arguments.
701701
702-
6. Call the function call dispatcher (func_data::impl)
702+
5. Call the function call dispatcher (func_data::impl)
703703
704704
If one of these fail, move on to the next overload and keep trying
705705
until we get a result other than NB_NEXT_OVERLOAD.
@@ -878,7 +878,8 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
878878
return result;
879879
}
880880

881-
/// Simplified nb_func_vectorcall variant for functions w/o keyword arguments
881+
/// Simplified nb_func_vectorcall variant for functions w/o keyword arguments,
882+
/// w/o default arguments, with no more than NB_MAXARGS_SIMPLE arguments, etc.
882883
static PyObject *nb_func_vectorcall_simple(PyObject *self,
883884
PyObject *const *args_in,
884885
size_t nargsf,

tests/test_classes.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ struct Struct {
3737
~Struct() { destructed++; if (nb::is_alive()) struct_destructed.push_back(i); }
3838

3939
int value() const { return i; }
40+
int value_plus(int j, int k, int l, int m, int n, int o, int p) const {
41+
return i + j + k + l + m + n + o + p;
42+
}
4043
int getstate() const { ++pickled; return i; }
4144
void set_value(int value) { i = value; }
4245
void setstate(int value) { unpickled++; i = value; }
@@ -163,6 +166,7 @@ NB_MODULE(test_classes_ext, m) {
163166
.def(nb::init<>())
164167
.def(nb::init<int>())
165168
.def("value", &Struct::value)
169+
.def("value_plus", &Struct::value_plus)
166170
.def("set_value", &Struct::set_value, "value"_a)
167171
.def("self", &Struct::self, nb::rv_policy::none)
168172
.def("none", [](Struct &) -> const Struct * { return nullptr; })

tests/test_classes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test02_static_overload():
5959
def test03_instantiate(clean):
6060
s1: t.Struct = t.Struct()
6161
assert s1.value() == 5
62+
assert s1.value_plus(1, 2, 3, 4, 5, 6, 7) == 33
6263
s2 = t.Struct(10)
6364
assert s2.value() == 10
6465
del s1

tests/test_classes_ext.pyi.ref

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ class Struct:
1212

1313
def value(self) -> int: ...
1414

15+
def value_plus(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, /) -> int: ...
16+
1517
def set_value(self, value: int) -> None: ...
1618

1719
def self(self) -> Struct: ...

tests/test_functions.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ NB_MODULE(test_functions_ext, m) {
138138
return std::make_pair(args.size(), kwargs.size());
139139
}, "a"_a, "b"_a, "myargs"_a, "mykwargs"_a);
140140

141+
/// Function with eight arguments
142+
m.def("test_simple",
143+
[](int i0, int i1, int i2, int i3, int i4, int i5, int i6, int i7) {
144+
return i0 + i1 + i2 + i3 + i4 + i5 + i6 - i7;
145+
});
146+
141147
/// Test successful/unsuccessful tuple conversion, with rich output types
142148
m.def("test_tuple", []() -> nb::typed<nb::tuple, std::string, int> {
143149
return nb::make_tuple("Hello", 123); });

tests/test_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def test01_capture():
2222
assert t.test_02(5, 3) == 2
2323
assert t.test_03(5, 3) == 44
2424
assert t.test_04() == 60
25+
assert t.test_simple(0, 1, 2, 3, 4, 5, 6, 7) == 14
2526

2627

2728
def test02_default_args():

tests/test_functions_ext.pyi.ref

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def test_07(arg0: int, arg1: int, /, *args, **kwargs) -> tuple[int, int]: ...
4141
@overload
4242
def test_07(a: int, b: int, *myargs, **mykwargs) -> tuple[int, int]: ...
4343

44+
def test_simple(arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, /) -> int: ...
45+
4446
@overload
4547
def test_tuple() -> tuple[str, int]: ...
4648

0 commit comments

Comments
 (0)