Skip to content

Commit f8c3e05

Browse files
committed
🐛 sparse: fix csc and csr return type of count_nonzero with axis given
1 parent 2d75c85 commit f8c3e05

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

scipy-stubs/sparse/_csc.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import Any, Generic, Literal, overload
22
from typing_extensions import TypeIs, TypeVar, override
33

4+
import numpy as np
45
import optype as op
6+
import optype.numpy as onp
57
from ._base import sparray
68
from ._compressed import _cs_matrix
79
from ._matrix import spmatrix
@@ -24,6 +26,12 @@ class _csc_base(_cs_matrix[_SCT, tuple[int, int]], Generic[_SCT]):
2426
@override
2527
def shape(self, /) -> tuple[int, int]: ...
2628

29+
#
30+
@overload # type: ignore[explicit-override]
31+
def count_nonzero(self, /, axis: None = None) -> int: ...
32+
@overload
33+
def count_nonzero(self, /, axis: op.CanIndex) -> onp.Array1D[np.intp]: ...
34+
2735
class csc_array(_csc_base[_SCT], sparray, Generic[_SCT]): ...
2836

2937
class csc_matrix(_csc_base[_SCT], spmatrix[_SCT], Generic[_SCT]): # type: ignore[misc]

scipy-stubs/sparse/_csr.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import Any, Generic, Literal, overload
22
from typing_extensions import TypeIs, TypeVar, override
33

4+
import numpy as np
45
import optype as op
6+
import optype.numpy as onp
57
from ._base import sparray
68
from ._compressed import _cs_matrix
79
from ._matrix import spmatrix
@@ -22,6 +24,16 @@ class _csr_base(_cs_matrix[_SCT, _ShapeT_co], Generic[_SCT, _ShapeT_co]):
2224
@override
2325
def format(self, /) -> Literal["csr"]: ...
2426

27+
#
28+
@overload # type: ignore[explicit-override]
29+
def count_nonzero(self, /, axis: None = None) -> int: ...
30+
@overload
31+
def count_nonzero(self: _csr_base[Any, tuple[int]], /, axis: op.CanIndex) -> int: ...
32+
@overload
33+
def count_nonzero(self: _csr_base[Any, tuple[int, int]], /, axis: op.CanIndex) -> onp.Array1D[np.intp]: ...
34+
@overload
35+
def count_nonzero(self: csr_array, /, axis: op.CanIndex) -> int | onp.Array1D[np.intp]: ... # type: ignore[misc]
36+
2537
class csr_array(_csr_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]): ...
2638

2739
class csr_matrix(_csr_base[_SCT, tuple[int, int]], spmatrix[_SCT], Generic[_SCT]): # type: ignore[misc]

0 commit comments

Comments
 (0)