77import types
88from typing import Any
99
10- from pybind11_stubgen .parser .errors import NameResolutionError
10+ from pybind11_stubgen .parser .errors import NameResolutionError , ParserError
1111from pybind11_stubgen .parser .interface import IParser
1212from pybind11_stubgen .structs import (
1313 Alias ,
2727 ResolvedType ,
2828 Value ,
2929)
30- from pybind11_stubgen .typing_ext import FixedSize
30+ from pybind11_stubgen .typing_ext import DynamicSize , FixedSize
3131
3232
3333class RemoveSelfAnnotation (IParser ):
@@ -252,6 +252,12 @@ def handle_type(self, type_: type) -> QualifiedName:
252252
253253 return result
254254
255+ def report_error (self , error : ParserError ):
256+ if isinstance (error , NameResolutionError ):
257+ if error .name [0 ] in ["PyCapsule" ]:
258+ return
259+ super ().report_error (error )
260+
255261
256262class FixRedundantBuiltinsAnnotation (IParser ):
257263 def handle_attribute (self , path : QualifiedName , attr : Any ) -> Attribute | None :
@@ -451,7 +457,14 @@ def value_to_repr(self, value: Any) -> str:
451457
452458
453459class FixNumpyArrayDimAnnotation (IParser ):
454- __ndarray_name = QualifiedName .from_str ("numpy.ndarray" )
460+ __array_names : set [QualifiedName ] = {
461+ QualifiedName .from_str ("numpy.ndarray" ),
462+ * (
463+ QualifiedName .from_str (f"scipy.sparse.{ storage } _{ arr } " )
464+ for storage in ["bsr" , "coo" , "csr" , "csc" , "dia" , "dok" , "lil" ]
465+ for arr in ["array" , "matrix" ]
466+ ),
467+ }
455468 __annotated_name = QualifiedName .from_str ("typing.Annotated" )
456469 numpy_primitive_types : set [QualifiedName ] = set (
457470 map (
@@ -475,47 +488,79 @@ def parse_annotation_str(
475488 self , annotation_str : str
476489 ) -> ResolvedType | InvalidExpression | Value :
477490 # Affects types of the following pattern:
478- # numpy.ndarray[PRIMITIVE_TYPE[*DIMS]]
479- # Annotated[numpy.ndarray, PRIMITIVE_TYPE, FixedSize[*DIMS]]
491+ # ARRAY_T[PRIMITIVE_TYPE[*DIMS], *FLAGS]
492+ # Replace with:
493+ # Annotated[ARRAY_T, PRIMITIVE_TYPE, FixedSize/DynamicSize[*DIMS], *FLAGS]
480494
481495 result = super ().parse_annotation_str (annotation_str )
482496 if (
483497 not isinstance (result , ResolvedType )
484- or result .name != self .__ndarray_name
498+ or result .name not in self .__array_names
485499 or result .parameters is None
486- or len (result .parameters ) != 1
500+ or len (result .parameters ) == 0
487501 ):
488502 return result
489- param = result .parameters [0 ]
503+
504+ scalar_with_dims = result .parameters [0 ] # e.g. numpy.float64[32, 32]
505+ flags = result .parameters [1 :]
506+
490507 if (
491- not isinstance (param , ResolvedType )
492- or param .name not in self .numpy_primitive_types
493- or param .parameters is None
494- or any (not isinstance (dim , Value ) for dim in param .parameters )
508+ not isinstance (scalar_with_dims , ResolvedType )
509+ or scalar_with_dims .name not in self .numpy_primitive_types
510+ or (
511+ scalar_with_dims .parameters is not None
512+ and any (
513+ not isinstance (dim , Value ) for dim in scalar_with_dims .parameters
514+ )
515+ )
495516 ):
496517 return result
497518
498- # isinstance check is redundant, but makes mypy happy
499- dims = [int (dim .repr ) for dim in param .parameters if isinstance (dim , Value )]
500-
501- # override result with Annotated[...]
502519 result = ResolvedType (
503520 name = self .__annotated_name ,
504521 parameters = [
505- ResolvedType ( self .__ndarray_name ),
506- ResolvedType (param .name ),
522+ self .parse_annotation_str ( str ( result . name ) ),
523+ ResolvedType (scalar_with_dims .name ),
507524 ],
508525 )
509526
510- if param .parameters is not None :
511- # TRICK: Use `self.parse_type` to make `FixedSize`
512- # properly added to the list of imports
513- self .handle_type (FixedSize )
514- assert result .parameters is not None
515- result .parameters += [self .handle_value (FixedSize (* dims ))]
527+ if (
528+ scalar_with_dims .parameters is not None
529+ and len (scalar_with_dims .parameters ) >= 0
530+ ):
531+ result .parameters += [
532+ self .handle_value (
533+ self ._cook_dimension_parameters (scalar_with_dims .parameters )
534+ )
535+ ]
536+
537+ result .parameters += flags
516538
517539 return result
518540
541+ def _cook_dimension_parameters (
542+ self , dimensions : list [Value ]
543+ ) -> FixedSize | DynamicSize :
544+ all_ints = True
545+ new_params = []
546+ for dim_param in dimensions :
547+ try :
548+ dim = int (dim_param .repr )
549+ except ValueError :
550+ dim = dim_param .repr
551+ all_ints = False
552+ new_params .append (dim )
553+
554+ if all_ints :
555+ return_t = FixedSize
556+ else :
557+ return_t = DynamicSize
558+
559+ # TRICK: Use `self.handle_type` to make `FixedSize`/`DynamicSize`
560+ # properly added to the list of imports
561+ self .handle_type (FixedSize )
562+ return return_t (* new_params )
563+
519564
520565class FixNumpyArrayRemoveParameters (IParser ):
521566 __ndarray_name = QualifiedName .from_str ("numpy.ndarray" )
@@ -529,6 +574,34 @@ def parse_annotation_str(
529574 return result
530575
531576
577+ class FixNumpyArrayFlags (IParser ):
578+ __ndarray_name = QualifiedName .from_str ("numpy.ndarray" )
579+ __flags : set [QualifiedName ] = {
580+ QualifiedName .from_str ("flags.writeable" ),
581+ QualifiedName .from_str ("flags.c_contiguous" ),
582+ QualifiedName .from_str ("flags.f_contiguous" ),
583+ }
584+
585+ def parse_annotation_str (
586+ self , annotation_str : str
587+ ) -> ResolvedType | InvalidExpression | Value :
588+ result = super ().parse_annotation_str (annotation_str )
589+ if isinstance (result , ResolvedType ) and result .name == self .__ndarray_name :
590+ if result .parameters is not None :
591+ for param in result .parameters :
592+ if param .name in self .__flags :
593+ param .name = QualifiedName .from_str (
594+ f"numpy.ndarray.{ param .name } "
595+ )
596+
597+ return result
598+
599+ def report_error (self , error : ParserError ) -> None :
600+ if isinstance (error , NameResolutionError ) and error .name in self .__flags :
601+ return
602+ super ().report_error (error )
603+
604+
532605class FixRedundantMethodsFromBuiltinObject (IParser ):
533606 def handle_method (self , path : QualifiedName , method : Any ) -> list [Method ]:
534607 result = super ().handle_method (path , method )
0 commit comments