Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cmake/nanobind-config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ endfunction()
# ---------------------------------------------------------------------------

function (nanobind_add_stub name)
cmake_parse_arguments(PARSE_ARGV 1 ARG "VERBOSE;INCLUDE_PRIVATE;EXCLUDE_DOCSTRINGS;INSTALL_TIME;RECURSIVE;EXCLUDE_FROM_ALL" "MODULE;COMPONENT;PATTERN_FILE;OUTPUT_PATH" "PYTHON_PATH;DEPENDS;MARKER_FILE;OUTPUT")
cmake_parse_arguments(PARSE_ARGV 1 ARG "VERBOSE;INCLUDE_PRIVATE;EXCLUDE_DOCSTRINGS;EXCLUDE_VALUES;INSTALL_TIME;RECURSIVE;EXCLUDE_FROM_ALL" "MODULE;COMPONENT;PATTERN_FILE;OUTPUT_PATH" "PYTHON_PATH;DEPENDS;MARKER_FILE;OUTPUT")

if (EXISTS ${NB_DIR}/src/stubgen.py)
set(NB_STUBGEN "${NB_DIR}/src/stubgen.py")
Expand All @@ -614,6 +614,10 @@ function (nanobind_add_stub name)
list(APPEND NB_STUBGEN_ARGS -D)
endif()

if (ARG_EXCLUDE_VALUES)
list(APPEND NB_STUBGEN_ARGS --exclude-values)
endif()

if (ARG_RECURSIVE)
list(APPEND NB_STUBGEN_ARGS -r)
endif()
Expand Down
3 changes: 2 additions & 1 deletion docs/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ The program has the following command line options:
.. code-block:: text

usage: python -m nanobind.stubgen [-h] [-o FILE] [-O PATH] [-i PATH] [-m MODULE]
[-r] [-M FILE] [-P] [-D] [-q]
[-r] [-M FILE] [-P] [-D] [--exclude-values] [-q]

Generate stubs for nanobind-based extensions.

Expand All @@ -559,6 +559,7 @@ The program has the following command line options:
-P, --include-private include private members (with single leading or
trailing underscore)
-D, --exclude-docstrings exclude docstrings from the generated stub
--exclude-values force the use of ... for values
-q, --quiet do not generate any output in the absence of failures


Expand Down
27 changes: 21 additions & 6 deletions src/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,12 +1004,14 @@ def expr_str(self, e: Any, abbrev: bool = True) -> Optional[str]:
"""
tp = type(e)
if issubclass(tp, (bool, int, type(None), type(builtins.Ellipsis))):
return repr(e)
s = repr(e)
if len(s) < self.max_expr_length or not abbrev:
return s
elif issubclass(tp, float):
s = repr(e)
if "inf" in s or "nan" in s:
return f"float('{s}')"
else:
s = f"float('{s}')"
if len(s) < self.max_expr_length or not abbrev:
return s
elif issubclass(tp, type) or typing.get_origin(e):
return self.type_str(e)
Expand All @@ -1025,13 +1027,17 @@ def expr_str(self, e: Any, abbrev: bool = True) -> Optional[str]:
tv = self.import_object("typing", "TypeVar")
s = f'{tv}("{e.__name__}"'
for v in getattr(e, "__constraints__", ()):
v = self.expr_str(v)
v = self.type_str(v)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched to type_str here and below because both constraints and bound are type expressions as per the typing spec.

assert v
s += ", " + v
for k in ["contravariant", "covariant", "bound", "infer_variance"]:
if v := getattr(e, "__bound__", None):
v = self.type_str(v)
assert v
s += ", bound=" + v
for k in ["contravariant", "covariant", "infer_variance"]:
v = getattr(e, f"__{k}__", None)
if v:
v = self.expr_str(v)
v = self.expr_str(v, abbrev=False)
if v is None:
return None
s += f", {k}=" + v
Expand Down Expand Up @@ -1319,6 +1325,14 @@ def parse_options(args: List[str]) -> argparse.Namespace:
help="exclude docstrings from the generated stub",
)

parser.add_argument(
"--exclude-values",
dest="exclude_values",
default=False,
action="store_true",
help="force the use of ... for values",
)

parser.add_argument(
"-q",
"--quiet",
Expand Down Expand Up @@ -1463,6 +1477,7 @@ def main(args: Optional[List[str]] = None) -> None:
recursive=opt.recursive,
include_docstrings=opt.include_docstrings,
include_private=opt.include_private,
max_expr_length=0 if opt.exclude_values else 50,
patterns=patterns,
output_file=file
)
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ foreach (NAME functions classes ndarray jax tensorflow stl enum typing make_iter
set(EXTRA
MARKER_FILE py.typed
PATTERN_FILE "${CMAKE_CURRENT_SOURCE_DIR}/pattern_file.nb"
EXCLUDE_VALUES
)
set(EXTRA_DEPENDS "${OUT_DIR}/py_stub_test.py")
else()
Expand Down
4 changes: 4 additions & 0 deletions tests/test_typing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ NB_MODULE(test_typing_ext, m) {
m.def("list_front", [](nb::list l) { return l[0]; },
nb::sig("def list_front[T](arg: list[T], /) -> T"));

// Type variables with constraints and a bound.
m.attr("T2") = nb::type_var("T2", "bound"_a = nb::type<Foo>());
m.attr("T3") = nb::type_var("T3", *nb::make_tuple(nb::type<Foo>(), nb::type<Wrapper>()));

// Some statements that will be modified by the pattern file
m.def("remove_me", []{});
m.def("tweak_me", [](nb::object o) { return o; }, "prior docstring\nremains preserved");
Expand Down
6 changes: 5 additions & 1 deletion tests/test_typing_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class CustomSignature(Iterable[int]):
def value(self, value: Optional[int], /) -> None:
"""docstring for setter"""

pytree: dict = {'a' : ('b', [123])}
pytree: dict = ...

T = TypeVar("T", contravariant=True)

Expand All @@ -63,6 +63,10 @@ class WrapperTypeParam[T]:

def list_front[T](arg: list[T], /) -> T: ...

T2 = TypeVar("T2", bound=Foo)

T3 = TypeVar("T3", Foo, Wrapper)

def tweak_me(arg: int):
"""
prior docstring
Expand Down