3131 Property ,
3232 QualifiedName ,
3333 ResolvedType ,
34+ TypeVar_ ,
3435 Value ,
3536)
3637from pybind11_stubgen .typing_ext import DynamicSize , FixedSize
@@ -335,6 +336,7 @@ class FixTypingTypeNames(IParser):
335336 "Iterator" ,
336337 "KeysView" ,
337338 "List" ,
339+ "Literal" ,
338340 "Optional" ,
339341 "Sequence" ,
340342 "Set" ,
@@ -360,12 +362,28 @@ def __init__(self):
360362 super ().__init__ ()
361363 if sys .version_info < (3 , 9 ):
362364 self .__typing_extensions_names .add (Identifier ("Annotated" ))
365+ if sys .version_info < (3 , 8 ):
366+ self .__typing_extensions_names .add (Identifier ("Literal" ))
363367
364368 def parse_annotation_str (
365369 self , annotation_str : str
366370 ) -> ResolvedType | InvalidExpression | Value :
367371 result = super ().parse_annotation_str (annotation_str )
368- if not isinstance (result , ResolvedType ) or len (result .name ) != 1 :
372+ return self ._parse_annotation_str (result )
373+
374+ def _parse_annotation_str (
375+ self , result : ResolvedType | InvalidExpression | Value
376+ ) -> ResolvedType | InvalidExpression | Value :
377+ if not isinstance (result , ResolvedType ):
378+ return result
379+
380+ result .parameters = (
381+ [self ._parse_annotation_str (p ) for p in result .parameters ]
382+ if result .parameters is not None
383+ else None
384+ )
385+
386+ if len (result .name ) != 1 :
369387 return result
370388
371389 word = result .name [0 ]
@@ -582,6 +600,136 @@ def report_error(self, error: ParserError) -> None:
582600 super ().report_error (error )
583601
584602
603+ class FixNumpyArrayDimTypeVar (IParser ):
604+ __array_names : set [QualifiedName ] = {QualifiedName .from_str ("numpy.ndarray" )}
605+ numpy_primitive_types = FixNumpyArrayDimAnnotation .numpy_primitive_types
606+
607+ __DIM_VARS : set [str ] = set ()
608+
609+ def handle_module (
610+ self , path : QualifiedName , module : types .ModuleType
611+ ) -> Module | None :
612+ result = super ().handle_module (path , module )
613+ if result is None :
614+ return None
615+
616+ if self .__DIM_VARS :
617+ # the TypeVar_'s generated code will reference `typing`
618+ result .imports .add (
619+ Import (name = None , origin = QualifiedName .from_str ("typing" ))
620+ )
621+
622+ for name in self .__DIM_VARS :
623+ result .type_vars .append (
624+ TypeVar_ (
625+ name = Identifier (name ),
626+ bound = self .parse_annotation_str ("int" ),
627+ ),
628+ )
629+
630+ self .__DIM_VARS .clear ()
631+
632+ return result
633+
634+ def parse_annotation_str (
635+ self , annotation_str : str
636+ ) -> ResolvedType | InvalidExpression | Value :
637+ # Affects types of the following pattern:
638+ # numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
639+ # Replace with:
640+ # numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
641+
642+ result = super ().parse_annotation_str (annotation_str )
643+
644+ if not isinstance (result , ResolvedType ):
645+ return result
646+
647+ # handle unqualified, single-letter annotation as a TypeVar
648+ if len (result .name ) == 1 and len (result .name [0 ]) == 1 :
649+ result .name = QualifiedName .from_str (result .name [0 ].upper ())
650+ self .__DIM_VARS .add (result .name [0 ])
651+
652+ if result .name not in self .__array_names :
653+ return result
654+
655+ # ndarray is generic and should have 2 type arguments
656+ if result .parameters is None or len (result .parameters ) == 0 :
657+ result .parameters = [
658+ self .parse_annotation_str ("Any" ),
659+ ResolvedType (
660+ name = QualifiedName .from_str ("numpy.dtype" ),
661+ parameters = [self .parse_annotation_str ("Any" )],
662+ ),
663+ ]
664+ return result
665+
666+ scalar_with_dims = result .parameters [0 ] # e.g. numpy.float64[32, 32]
667+
668+ if (
669+ not isinstance (scalar_with_dims , ResolvedType )
670+ or scalar_with_dims .name not in self .numpy_primitive_types
671+ ):
672+ return result
673+
674+ dtype = ResolvedType (
675+ name = QualifiedName .from_str ("numpy.dtype" ),
676+ parameters = [ResolvedType (name = scalar_with_dims .name )],
677+ )
678+
679+ shape = self .parse_annotation_str ("Any" )
680+ if (
681+ scalar_with_dims .parameters is not None
682+ and len (scalar_with_dims .parameters ) > 0
683+ ):
684+ dims = self .__to_dims (scalar_with_dims .parameters )
685+ if dims is not None :
686+ shape = self .parse_annotation_str ("Tuple" )
687+ assert isinstance (shape , ResolvedType )
688+ shape .parameters = []
689+ for dim in dims :
690+ if isinstance (dim , int ):
691+ # self.parse_annotation_str will qualify Literal with either
692+ # typing or typing_extensions and add the import to the module
693+ literal_dim = self .parse_annotation_str ("Literal" )
694+ assert isinstance (literal_dim , ResolvedType )
695+ literal_dim .parameters = [Value (repr = str (dim ))]
696+ shape .parameters .append (literal_dim )
697+ else :
698+ shape .parameters .append (
699+ ResolvedType (name = QualifiedName .from_str (dim ))
700+ )
701+
702+ result .parameters = [shape , dtype ]
703+ return result
704+
705+ def __to_dims (
706+ self , dimensions : Sequence [ResolvedType | Value | InvalidExpression ]
707+ ) -> list [int | str ] | None :
708+ result : list [int | str ] = []
709+ for dim_param in dimensions :
710+ if isinstance (dim_param , Value ):
711+ try :
712+ dim = int (dim_param .repr )
713+ except ValueError :
714+ return None
715+ elif isinstance (dim_param , ResolvedType ):
716+ dim = str (dim_param )
717+ else :
718+ return None
719+ result .append (dim )
720+ return result
721+
722+ def report_error (self , error : ParserError ) -> None :
723+ if (
724+ isinstance (error , NameResolutionError )
725+ and len (error .name ) == 1
726+ and error .name [0 ] in self .__DIM_VARS
727+ ):
728+ # allow type variables, which are manually resolved in `handle_module`
729+ return
730+ super ().report_error (error )
731+
732+
585733class FixNumpyArrayRemoveParameters (IParser ):
586734 __ndarray_name = QualifiedName .from_str ("numpy.ndarray" )
587735
@@ -594,24 +742,40 @@ def parse_annotation_str(
594742 return result
595743
596744
745+ class FixScipyTypeArguments (IParser ):
746+ def parse_annotation_str (
747+ self , annotation_str : str
748+ ) -> ResolvedType | InvalidExpression | Value :
749+ result = super ().parse_annotation_str (annotation_str )
750+
751+ if not isinstance (result , ResolvedType ):
752+ return result
753+
754+ # scipy.sparse arrays/matrices are not currently generic and do not accept type
755+ # arguments
756+ if result .name [:2 ] == ("scipy" , "sparse" ):
757+ result .parameters = None
758+
759+ return result
760+
761+
597762class FixNumpyDtype (IParser ):
598763 __numpy_dtype = QualifiedName .from_str ("numpy.dtype" )
599764
600765 def parse_annotation_str (
601766 self , annotation_str : str
602767 ) -> ResolvedType | InvalidExpression | Value :
603768 result = super ().parse_annotation_str (annotation_str )
604- if (
605- not isinstance (result , ResolvedType )
606- or len (result .name ) != 1
607- or result .parameters is not None
608- ):
609- return result
610769
611- word = result .name [0 ]
612- if word != Identifier ("dtype" ):
770+ if not isinstance (result , ResolvedType ) or result .parameters :
613771 return result
614- return ResolvedType (name = self .__numpy_dtype )
772+
773+ # numpy.dtype is generic and should have a type argument
774+ if result .name [:1 ] == ("dtype" ,) or result .name [:2 ] == ("numpy" , "dtype" ):
775+ result .name = self .__numpy_dtype
776+ result .parameters = [self .parse_annotation_str ("Any" )]
777+
778+ return result
615779
616780
617781class FixNumpyArrayFlags (IParser ):
0 commit comments