66import re
77import sys
88import types
9+ from collections import defaultdict
910from logging import getLogger
1011from typing import Any , Sequence
1112
1213from pybind11_stubgen .parser .errors import (
14+ AmbiguousEnumError ,
1315 InvalidExpressionError ,
1416 NameResolutionError ,
1517 ParserError ,
@@ -999,13 +1001,56 @@ def parse_value_str(self, value: str) -> Value | InvalidExpression:
9991001
10001002
10011003class RewritePybind11EnumValueRepr (IParser ):
1004+ """Reformat pybind11-generated invalid enum value reprs.
1005+
1006+ For example, pybind11 may generate a `__doc__` like this:
1007+ >>> "set_color(self, color: <ConsoleForegroundColor.Blue: 34>) -> None:\n "
1008+
1009+ Which is invalid python syntax. This parser will rewrite the generated stub to:
1010+ >>> from demo._bindings.enum import ConsoleForegroundColor
1011+ >>> def set_color(self, color: ConsoleForegroundColor.Blue) -> None:
1012+ >>> ...
1013+
1014+ Since `pybind11_stubgen` encounters the values corresponding to these reprs as it
1015+ parses the modules, it can automatically replace these invalid expressions with the
1016+ corresponding `Value` and `Import` as it encounters them. There are 3 cases for an
1017+ `Argument` whose `default` is an enum `InvalidExpression`:
1018+
1019+ 1. The `InvalidExpression` repr corresponds to exactly one enum field definition.
1020+ The `InvalidExpression` is simply replaced by the corresponding `Value`.
1021+ 2. The `InvalidExpression` repr corresponds to multiple enum field definitions. An
1022+ `AmbiguousEnumError` is reported.
1023+ 3. The `InvalidExpression` repr corresponds to no enum field definitions. An
1024+ `InvalidExpressionError` is reported.
1025+
1026+ Attributes:
1027+ _pybind11_enum_pattern: Pattern matching pybind11 enum field reprs.
1028+ _unknown_enum_classes: Set of the str names of enum classes whose reprs were not
1029+ seen.
1030+ _invalid_default_arguments: Per module invalid arguments. Used to know which
1031+ enum imports to add to the current module.
1032+ _repr_to_value_and_import: Saved safe print values of enum field reprs and the
1033+ import to add to a module when when that repr is seen.
1034+ _repr_to_invalid_default_arguments: Groups of arguments whose default values are
1035+ `InvalidExpression`s. This is only used until the first time each repr is
1036+ seen. Left over groups will raise an error, which may be fixed using
1037+ `--enum-class-locations` or suppressed using `--ignore-invalid-expressions`.
1038+ _invalid_default_argument_to_module: Maps individual invalid default arguments
1039+ to the module containing them. Used to know which enum imports to add to
1040+ which module.
1041+ """
1042+
10021043 _pybind11_enum_pattern = re .compile (r"<(?P<enum>\w+(\.\w+)+): (?P<value>-?\d+)>" )
1003- # _pybind11_enum_pattern = re.compile(r"<(?P<enum>\w+(\.\w+)+): (?P<value>\d+)>")
10041044 _unknown_enum_classes : set [str ] = set ()
1045+ _invalid_default_arguments : list [Argument ] = []
1046+ _repr_to_value_and_import : dict [str , set [tuple [Value , Import ]]] = defaultdict (set )
1047+ _repr_to_invalid_default_arguments : dict [str , set [Argument ]] = defaultdict (set )
1048+ _invalid_default_argument_to_module : dict [Argument , Module ] = {}
10051049
10061050 def __init__ (self ):
10071051 super ().__init__ ()
10081052 self ._pybind11_enum_locations : dict [re .Pattern , str ] = {}
1053+ self ._is_finalizing = False
10091054
10101055 def set_pybind11_enum_locations (self , locations : dict [re .Pattern , str ]):
10111056 self ._pybind11_enum_locations = locations
@@ -1024,17 +1069,104 @@ def parse_value_str(self, value: str) -> Value | InvalidExpression:
10241069 return Value (repr = f"{ enum_class .name } .{ entry } " , is_print_safe = True )
10251070 return super ().parse_value_str (value )
10261071
1072+ def handle_module (
1073+ self , path : QualifiedName , module : types .ModuleType
1074+ ) -> Module | None :
1075+ # we may be handling a module within a module, so save the parent's invalid
1076+ # arguments on the stack as we handle this module
1077+ parent_module_invalid_arguments = self ._invalid_default_arguments
1078+ self ._invalid_default_arguments = []
1079+ result = super ().handle_module (path , module )
1080+
1081+ if result is None :
1082+ self ._invalid_default_arguments = parent_module_invalid_arguments
1083+ return None
1084+
1085+ # register each argument to the current module
1086+ while self ._invalid_default_arguments :
1087+ arg = self ._invalid_default_arguments .pop ()
1088+ assert isinstance (arg .default , InvalidExpression )
1089+ repr_ = arg .default .text
1090+ self ._repr_to_invalid_default_arguments [repr_ ].add (arg )
1091+ self ._invalid_default_argument_to_module [arg ] = result
1092+
1093+ self ._invalid_default_arguments = parent_module_invalid_arguments
1094+ return result
1095+
1096+ def handle_function (self , path : QualifiedName , func : Any ) -> list [Function ]:
1097+ result = super ().handle_function (path , func )
1098+
1099+ for f in result :
1100+ for arg in f .args :
1101+ if isinstance (arg .default , InvalidExpression ):
1102+ # this argument will be registered to the current module
1103+ self ._invalid_default_arguments .append (arg )
1104+
1105+ return result
1106+
1107+ def handle_attribute (self , path : QualifiedName , attr : Any ) -> Attribute | None :
1108+ module = inspect .getmodule (attr )
1109+ repr_ = repr (attr )
1110+
1111+ if module is not None :
1112+ module_path = QualifiedName .from_str (module .__name__ )
1113+ is_source_module = path [: len (module_path )] == module_path
1114+ is_alias = ( # could be an `.export_values()` alias, which we want to avoid
1115+ is_source_module
1116+ and not inspect .isclass (getattr (module , path [len (module_path )]))
1117+ )
1118+
1119+ if not is_alias and is_source_module :
1120+ # register one of the possible sources of this repr
1121+ self ._repr_to_value_and_import [repr_ ].add (
1122+ (
1123+ Value (repr = "." .join (path ), is_print_safe = True ),
1124+ Import (name = None , origin = module_path ),
1125+ )
1126+ )
1127+
1128+ return super ().handle_attribute (path , attr )
1129+
10271130 def report_error (self , error : ParserError ) -> None :
1028- if isinstance (error , InvalidExpressionError ):
1131+ # defer reporting invalid enum expressions until finalization
1132+ if not self ._is_finalizing and isinstance (error , InvalidExpressionError ):
10291133 match = self ._pybind11_enum_pattern .match (error .expression )
10301134 if match is not None :
1135+ return
1136+ super ().report_error (error )
1137+
1138+ def finalize (self ) -> None :
1139+ self ._is_finalizing = True
1140+ for repr_ , args in self ._repr_to_invalid_default_arguments .items ():
1141+ match = self ._pybind11_enum_pattern .match (repr_ )
1142+ if match is None :
1143+ pass
1144+ elif repr_ not in self ._repr_to_value_and_import :
10311145 enum_qual_name = match .group ("enum" )
1032- enum_class_str , entry = enum_qual_name .rsplit ("." , maxsplit = 1 )
1146+ enum_class_str , _ = enum_qual_name .rsplit ("." , maxsplit = 1 )
10331147 self ._unknown_enum_classes .add (enum_class_str )
1034- super ().report_error (error )
1148+ self .report_error (InvalidExpressionError (repr_ ))
1149+ elif len (self ._repr_to_value_and_import [repr_ ]) > 1 :
1150+ self .report_error (
1151+ AmbiguousEnumError (repr_ , * self ._repr_to_value_and_import [repr_ ])
1152+ )
1153+ else :
1154+ # fix the invalid enum expressions
1155+ value , import_ = self ._repr_to_value_and_import [repr_ ].pop ()
1156+ for arg in args :
1157+ module = self ._invalid_default_argument_to_module [arg ]
1158+ if module .origin == import_ .origin :
1159+ arg .default = Value (
1160+ repr = value .repr [len (str (module .origin )) + 1 :],
1161+ is_print_safe = True ,
1162+ )
1163+ else :
1164+ arg .default = value
1165+ module .imports .add (import_ )
10351166
1036- def finalize (self ):
10371167 if self ._unknown_enum_classes :
1168+ # TODO: does this case still exist in practice? How would pybind11 display
1169+ # a repr for an enum field whose definition we did not see while parsing?
10381170 logger .warning (
10391171 "Enum-like str representations were found with no "
10401172 "matching mapping to the enum class location.\n "
0 commit comments