1
1
from typing import Any , Generic , Literal , overload
2
2
from typing_extensions import TypeIs , TypeVar , override
3
3
4
+ import numpy as np
4
5
import optype as op
6
+ import optype .numpy as onp
5
7
from ._base import sparray
6
8
from ._compressed import _cs_matrix
7
9
from ._matrix import spmatrix
@@ -22,6 +24,16 @@ class _csr_base(_cs_matrix[_SCT, _ShapeT_co], Generic[_SCT, _ShapeT_co]):
22
24
@override
23
25
def format (self , / ) -> Literal ["csr" ]: ...
24
26
27
+ #
28
+ @overload # type: ignore[explicit-override]
29
+ def count_nonzero (self , / , axis : None = None ) -> int : ...
30
+ @overload
31
+ def count_nonzero (self : _csr_base [Any , tuple [int ]], / , axis : op .CanIndex ) -> int : ...
32
+ @overload
33
+ def count_nonzero (self : _csr_base [Any , tuple [int , int ]], / , axis : op .CanIndex ) -> onp .Array1D [np .intp ]: ...
34
+ @overload
35
+ def count_nonzero (self : csr_array , / , axis : op .CanIndex ) -> int | onp .Array1D [np .intp ]: ... # type: ignore[misc]
36
+
25
37
class csr_array (_csr_base [_SCT , _ShapeT_co ], sparray , Generic [_SCT , _ShapeT_co ]): ...
26
38
27
39
class csr_matrix (_csr_base [_SCT , tuple [int , int ]], spmatrix [_SCT ], Generic [_SCT ]): # type: ignore[misc]
0 commit comments