Skip to content

Commit 0d95c22

Browse files
Improved stubs for block_array and bmat
1 parent 05684da commit 0d95c22

File tree

2 files changed

+177
-59
lines changed

2 files changed

+177
-59
lines changed

scipy-stubs/sparse/_construct.pyi

Lines changed: 159 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ _Offsets: TypeAlias = onp.ToInt | onp.ToInt1D
8383

8484
_DataRVS: TypeAlias = Callable[[int], onp.ArrayND[_Numeric]]
8585

86-
_ToBlocks: TypeAlias = Seq[Seq[_spbase[_SCT] | None]] | onp.ArrayND[np.object_]
86+
_ToBlocksArray: TypeAlias = onp.ArrayND[np.object_]
87+
_ToBlocksSPArray: TypeAlias = Seq[Seq[_SpArray2D[_SCT] | None]]
88+
_ToBlocksSPMatrix: TypeAlias = Seq[Seq[_SpMatrix[_SCT] | None]]
89+
_ToBlocksCanStack: TypeAlias = Seq[Seq[_CanStack[_T] | None]]
90+
_ToBlocksCanStackAs: TypeAlias = Seq[Seq[_CanStackAs[_SCT0, _T] | None]]
91+
_ToBlocksUnkown: TypeAlias = _ToBlocksSPArray | _ToBlocksArray
8792

8893
@type_check_only
8994
class _DataSampler(Protocol):
@@ -1385,19 +1390,19 @@ def kronsum(A: onp.ToComplex2D, B: onp.ToComplex2D, format: _Format | None = Non
13851390
@overload # sparray, format: <default>, dtype: <default>
13861391
def hstack(blocks: Seq[_CanStack[_T]], format: None = None, dtype: None = None) -> _T: ...
13871392
@overload # sparray, format: "bsr", dtype: <default>
1388-
def hstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtBSR, dtype: None = None) -> _BSRArray[_SCT0]: ...
1393+
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtBSR, dtype: None = None) -> _BSRArray[_SCT]: ...
13891394
@overload # sparray, format: "coo", dtype: <default>
1390-
def hstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtCOO, dtype: None = None) -> _COOArray2D[_SCT0]: ...
1395+
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtCOO, dtype: None = None) -> _COOArray2D[_SCT]: ...
13911396
@overload # sparray, format: "csc", dtype: <default>
1392-
def hstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtCSC, dtype: None = None) -> _CSCArray[_SCT0]: ...
1397+
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtCSC, dtype: None = None) -> _CSCArray[_SCT]: ...
13931398
@overload # sparray, format: "csr", dtype: <default>
1394-
def hstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtCSR, dtype: None = None) -> _CSRArray2D[_SCT0]: ...
1399+
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtCSR, dtype: None = None) -> _CSRArray2D[_SCT]: ...
13951400
@overload # sparray, format: "dia", dtype: <default>
1396-
def hstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtDIA, dtype: None = None) -> _DIAArray[_SCT0]: ...
1401+
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtDIA, dtype: None = None) -> _DIAArray[_SCT]: ...
13971402
@overload # sparray, format: "dok", dtype: <default>
1398-
def hstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtDOK, dtype: None = None) -> _DOKArray2D[_SCT0]: ...
1403+
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtDOK, dtype: None = None) -> _DOKArray2D[_SCT]: ...
13991404
@overload # sparray, format: "lil", dtype: <default>
1400-
def hstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtLIL, dtype: None = None) -> _LILArray[_SCT0]: ...
1405+
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtLIL, dtype: None = None) -> _LILArray[_SCT]: ...
14011406

14021407
#
14031408
@overload # sparray, format: <default>, dtype: bool-like
@@ -1516,19 +1521,19 @@ def hstack(blocks: Seq[_spbase], format: _Format, dtype: npt.DTypeLike | None =
15161521
@overload # sparray, format: <default>, dtype: <default>
15171522
def vstack(blocks: Seq[_CanStack[_T]], format: None = None, dtype: None = None) -> _T: ...
15181523
@overload # sparray, format: "bsr", dtype: <default>
1519-
def vstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtBSR, dtype: None = None) -> _BSRArray[_SCT0]: ...
1524+
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtBSR, dtype: None = None) -> _BSRArray[_SCT]: ...
15201525
@overload # sparray, format: "coo", dtype: <default>
1521-
def vstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtCOO, dtype: None = None) -> _COOArray2D[_SCT0]: ...
1526+
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtCOO, dtype: None = None) -> _COOArray2D[_SCT]: ...
15221527
@overload # sparray, format: "csc", dtype: <default>
1523-
def vstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtCSC, dtype: None = None) -> _CSCArray[_SCT0]: ...
1528+
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtCSC, dtype: None = None) -> _CSCArray[_SCT]: ...
15241529
@overload # sparray, format: "csr", dtype: <default>
1525-
def vstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtCSR, dtype: None = None) -> _CSRArray2D[_SCT0]: ...
1530+
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtCSR, dtype: None = None) -> _CSRArray2D[_SCT]: ...
15261531
@overload # sparray, format: "dia", dtype: <default>
1527-
def vstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtDIA, dtype: None = None) -> _DIAArray[_SCT0]: ...
1532+
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtDIA, dtype: None = None) -> _DIAArray[_SCT]: ...
15281533
@overload # sparray, format: "dok", dtype: <default>
1529-
def vstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtDOK, dtype: None = None) -> _DOKArray2D[_SCT0]: ...
1534+
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtDOK, dtype: None = None) -> _DOKArray2D[_SCT]: ...
15301535
@overload # sparray, format: "lil", dtype: <default>
1531-
def vstack(blocks: Seq[_SpArray[_SCT0]], format: _FmtLIL, dtype: None = None) -> _LILArray[_SCT0]: ...
1536+
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtLIL, dtype: None = None) -> _LILArray[_SCT]: ...
15321537

15331538
#
15341539
@overload # sparray, format: <default>, dtype: bool-like
@@ -1643,52 +1648,152 @@ def vstack(blocks: Seq[_SpArray], format: _FmtLIL, dtype: npt.DTypeLike) -> _LIL
16431648
def vstack(blocks: Seq[_spbase], format: _Format, dtype: npt.DTypeLike | None = None) -> Incomplete: ...
16441649

16451650
###
1646-
# TODO(jorenham): Use `_CanStack` here, which requires a way to map matrix types to array types.
16471651
@overload # blocks: <known>, format: <default>, dtype: <default>
1648-
def block_array(blocks: _ToBlocks[_SCT], *, format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ...
1649-
@overload # blocks: <known>, format: <otherwise>, dtype: <default>
1650-
def block_array(blocks: _ToBlocks[_SCT], *, format: _FmtNonCOO, dtype: None = None) -> _SpArray2D[_SCT]: ...
1651-
1652-
#
1653-
@overload # blocks: <unknown>, format: <default>, dtype: <known>
1654-
def block_array(blocks: _ToBlocks, *, format: _FmtCOO | None = None, dtype: onp.ToDType[_SCT]) -> _COOArray2D[_SCT]: ...
1655-
@overload # blocks: <unknown>, format: <otherwise>, dtype: <known>
1656-
def block_array(blocks: _ToBlocks, *, format: _FmtNonCOO, dtype: onp.ToDType[_SCT]) -> _SpArray2D[_SCT]: ...
1657-
1658-
#
1659-
@overload # blocks: <unknown>, format: <default>, dtype: <unknown>
1660-
def block_array(blocks: _ToBlocks, *, format: _FmtCOO | None = None, dtype: npt.DTypeLike) -> _COOArray2D: ...
1661-
@overload # blocks: <unknown>, format: <otherwise>, dtype: <unknown>
1662-
def block_array(blocks: _ToBlocks, *, format: _FmtNonCOO, dtype: npt.DTypeLike) -> _SpArray2D: ...
1652+
def block_array(blocks: _ToBlocksCanStack[_T], format: None = None, dtype: None = None) -> _T: ...
1653+
@overload # blocks: <known>, format: "bsr", dtype: <default>
1654+
def block_array(blocks: _ToBlocksSPArray[_SCT], format: _FmtBSR, dtype: None = None) -> _BSRArray[_SCT]: ...
1655+
@overload # blocks: <known>, format: "coo", dtype: <default>
1656+
def block_array(blocks: _ToBlocksSPArray[_SCT], format: _FmtCOO, dtype: None = None) -> _COOArray2D[_SCT]: ...
1657+
@overload # blocks: <known>, format: "csc", dtype: <default>
1658+
def block_array(blocks: _ToBlocksSPArray[_SCT], format: _FmtCSC, dtype: None = None) -> _CSCArray[_SCT]: ...
1659+
@overload # blocks: <known>, format: "csr", dtype: <default>
1660+
def block_array(blocks: _ToBlocksSPArray[_SCT], format: _FmtCSR, dtype: None = None) -> _CSRArray2D[_SCT]: ...
1661+
@overload # blocks: <known>, format: "dia", dtype: <default>
1662+
def block_array(blocks: _ToBlocksSPArray[_SCT], format: _FmtDIA, dtype: None = None) -> _DIAArray[_SCT]: ...
1663+
@overload # blocks: <known>, format: "dok", dtype: <default>
1664+
def block_array(blocks: _ToBlocksSPArray[_SCT], format: _FmtDOK, dtype: None = None) -> _DOKArray2D[_SCT]: ...
1665+
@overload # blocks: <known>, format: "lil", dtype: <default>
1666+
def block_array(blocks: _ToBlocksSPArray[_SCT], format: _FmtLIL, dtype: None = None) -> _LILArray[_SCT]: ...
1667+
1668+
#
1669+
@overload # blocks: <known>, format: <default>, dtype: bool-like
1670+
def block_array(blocks: _ToBlocksCanStackAs[np.bool_, _T], format: None = None, *, dtype: onp.AnyBoolDType) -> _T: ...
1671+
@overload # blocks: <unknown>, format: "bsr", dtype: bool-like
1672+
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: onp.AnyBoolDType) -> _BSRArray[np.bool_]: ...
1673+
@overload # blocks: <unknown>, format: "coo", dtype: bool-like
1674+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: onp.AnyBoolDType) -> _COOArray2D[np.bool_]: ...
1675+
@overload # blocks: <unknown>, format: "csc", dtype: bool-like
1676+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: onp.AnyBoolDType) -> _CSCArray[np.bool_]: ...
1677+
@overload # blocks: <unknown>, format: "csr", dtype: bool-like
1678+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: onp.AnyBoolDType) -> _CSRArray2D[np.bool_]: ...
1679+
@overload # blocks: <unknown>, format: "dia", dtype: bool-like
1680+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: onp.AnyBoolDType) -> _DIAArray[np.bool_]: ...
1681+
@overload # blocks: <unknown>, format: "dok", dtype: bool-like
1682+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: onp.AnyBoolDType) -> _DOKArray2D[np.bool_]: ...
1683+
@overload # blocks: <unknown>, format: "lil", dtype: bool-like
1684+
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: onp.AnyBoolDType) -> _LILArray[np.bool_]: ...
1685+
1686+
#
1687+
@overload # blocks: <known>, format: <default>, dtype: int-like
1688+
def block_array(blocks: _ToBlocksCanStackAs[np.int_, _T], format: None = None, *, dtype: onp.AnyIntDType) -> _T: ...
1689+
@overload # blocks: <unknown>, format: "bsr", dtype: int-like
1690+
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: onp.AnyIntDType) -> _BSRArray[np.int_]: ...
1691+
@overload # blocks: <unknown>, format: "coo", dtype: int-like
1692+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: onp.AnyIntDType) -> _COOArray2D[np.int_]: ...
1693+
@overload # blocks: <unknown>, format: "csc", dtype: int-like
1694+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: onp.AnyIntDType) -> _CSCArray[np.int_]: ...
1695+
@overload # blocks: <unknown>, format: "csr", dtype: int-like
1696+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: onp.AnyIntDType) -> _CSRArray2D[np.int_]: ...
1697+
@overload # blocks: <unknown>, format: "dia", dtype: int-like
1698+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: onp.AnyIntDType) -> _DIAArray[np.int_]: ...
1699+
@overload # blocks: <unknown>, format: "dok", dtype: int-like
1700+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: onp.AnyIntDType) -> _DOKArray2D[np.int_]: ...
1701+
@overload # blocks: <unknown>, format: "lil", dtype: int-like
1702+
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: onp.AnyIntDType) -> _LILArray[np.int_]: ...
1703+
1704+
#
1705+
@overload # blocks: <known>, format: <default>, dtype: float64-like
1706+
def block_array(blocks: _ToBlocksCanStackAs[np.float64, _T], format: None = None, *, dtype: onp.AnyFloat64DType) -> _T: ...
1707+
@overload # blocks: <unknown>, format: "bsr", dtype: float64-like
1708+
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: onp.AnyFloat64DType) -> _BSRArray[np.float64]: ...
1709+
@overload # blocks: <unknown>, format: "coo", dtype: float64-like
1710+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: onp.AnyFloat64DType) -> _COOArray2D[np.float64]: ...
1711+
@overload # blocks: <unknown>, format: "csc", dtype: float64-like
1712+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: onp.AnyFloat64DType) -> _CSCArray[np.float64]: ...
1713+
@overload # blocks: <unknown>, format: "csr", dtype: float64-like
1714+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: onp.AnyFloat64DType) -> _CSRArray2D[np.float64]: ...
1715+
@overload # blocks: <unknown>, format: "dia", dtype: float64-like
1716+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: onp.AnyFloat64DType) -> _DIAArray[np.float64]: ...
1717+
@overload # blocks: <unknown>, format: "dok", dtype: float64-like
1718+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: onp.AnyFloat64DType) -> _DOKArray2D[np.float64]: ...
1719+
@overload # blocks: <unknown>, format: "lil", dtype: float64-like
1720+
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: onp.AnyFloat64DType) -> _LILArray[np.float64]: ...
1721+
1722+
#
1723+
@overload # blocks: <known>, format: <default>, dtype: complex128-like
1724+
def block_array(blocks: _ToBlocksCanStackAs[np.complex128, _T], format: None = None, *, dtype: onp.AnyComplex128DType) -> _T: ...
1725+
@overload # blocks: <unknown>, format: "bsr", dtype: complex128-like
1726+
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: onp.AnyComplex128DType) -> _BSRArray[np.complex128]: ...
1727+
@overload # blocks: <unknown>, format: "coo", dtype: complex128-like
1728+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: onp.AnyComplex128DType) -> _COOArray2D[np.complex128]: ...
1729+
@overload # blocks: <unknown>, format: "csc", dtype: complex128-like
1730+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: onp.AnyComplex128DType) -> _CSCArray[np.complex128]: ...
1731+
@overload # blocks: <unknown>, format: "csr", dtype: complex128-like
1732+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: onp.AnyComplex128DType) -> _CSRArray2D[np.complex128]: ...
1733+
@overload # blocks: <unknown>, format: "dia", dtype: complex128-like
1734+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: onp.AnyComplex128DType) -> _DIAArray[np.complex128]: ...
1735+
@overload # blocks: <unknown>, format: "dok", dtype: complex128-like
1736+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: onp.AnyComplex128DType) -> _DOKArray2D[np.complex128]: ...
1737+
@overload # blocks: <unknown>, format: "lil", dtype: complex128-like
1738+
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: onp.AnyComplex128DType) -> _LILArray[np.complex128]: ...
1739+
1740+
#
1741+
@overload # blocks: <known>, format: <default>, dtype: <known>
1742+
def block_array(blocks: _ToBlocksCanStackAs[_SCT0, _T], format: None = None, *, dtype: onp.ToDType[_SCT0]) -> _T: ...
1743+
@overload # blocks: <unknown>, format: "bsr", dtype: <known>
1744+
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: onp.ToDType[_SCT0]) -> _BSRArray[_SCT0]: ...
1745+
@overload # blocks: <unknown>, format: "coo", dtype: <known>
1746+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: onp.ToDType[_SCT0]) -> _COOArray2D[_SCT0]: ...
1747+
@overload # blocks: <unknown>, format: "csc", dtype: <known>
1748+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: onp.ToDType[_SCT0]) -> _CSCArray[_SCT0]: ...
1749+
@overload # blocks: <unknown>, format: "csr", dtype: <known>
1750+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: onp.ToDType[_SCT0]) -> _CSRArray2D[_SCT0]: ...
1751+
@overload # blocks: <unknown>, format: "dia", dtype: <known>
1752+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: onp.ToDType[_SCT0]) -> _DIAArray[_SCT0]: ...
1753+
@overload # blocks: <unknown>, format: "dok", dtype: <known>
1754+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: onp.ToDType[_SCT0]) -> _DOKArray2D[_SCT0]: ...
1755+
@overload # blocks: <unknown>, format: "lil", dtype: <known>
1756+
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: onp.ToDType[_SCT0]) -> _LILArray[_SCT0]: ...
1757+
1758+
#
1759+
@overload # blocks: <known>, format: <default>, dtype: <unknown>
1760+
def block_array(blocks: _ToBlocksCanStackAs[Any, _T], format: None = None, *, dtype: npt.DTypeLike) -> _T: ...
1761+
@overload # blocks: <unknown>, format: "bsr", dtype: <unknown>
1762+
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: npt.DTypeLike) -> _BSRArray: ...
1763+
@overload # blocks: <unknown>, format: "coo", dtype: <unknown>
1764+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: npt.DTypeLike) -> _COOArray2D: ...
1765+
@overload # blocks: <unknown>, format: "csc", dtype: <unknown>
1766+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: npt.DTypeLike) -> _CSCArray: ...
1767+
@overload # blocks: <unknown>, format: "csr", dtype: <unknown>
1768+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: npt.DTypeLike) -> _CSRArray2D: ...
1769+
@overload # blocks: <unknown>, format: "dia", dtype: <unknown>
1770+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: npt.DTypeLike) -> _DIAArray: ...
1771+
@overload # blocks: <unknown>, format: "dok", dtype: <unknown>
1772+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: npt.DTypeLike) -> _DOKArray2D: ...
1773+
@overload # blocks: <unknown>, format: "lil", dtype: <unknown>
1774+
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: npt.DTypeLike) -> _LILArray: ...
16631775

16641776
###
1665-
# TODO(jorenham): Use `_CanStack` here, which requires a way to map array types to matrix types.
1666-
@overload # blocks: <array, known>, format: <default>, dtype: <default>
1667-
def bmat(blocks: Seq[Seq[sparray[_SCT]]], format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ...
1668-
@overload # blocks: <array, known>, format: <otherwise>, dtype: <default>
1669-
def bmat(blocks: Seq[Seq[sparray[_SCT]]], format: _FmtNonCOO, dtype: None = None) -> _SpArray2D[_SCT]: ...
1670-
1671-
#
1672-
@overload # blocks: <matrix, known>, format: <default>, dtype: <default>
1673-
def bmat(blocks: Seq[Seq[spmatrix[_SCT]]], format: _FmtCOO | None = None, dtype: None = None) -> coo_matrix[_SCT]: ...
1674-
@overload # blocks: <matrix, known>, format: <otherwise>, dtype: <default>
1675-
def bmat(blocks: Seq[Seq[spmatrix[_SCT]]], format: _FmtNonCOO, dtype: None = None) -> _SpMatrix[_SCT]: ...
1777+
@overload # blocks: <known>, format: <default>, dtype: <default>
1778+
def bmat(blocks: _ToBlocksCanStack[_T], format: None = None, dtype: None = None) -> _T: ...
1779+
@overload # blocks: <matrix>, format: <otherwise>, dtype: <default>
1780+
def bmat(blocks:_ToBlocksSPMatrix[_SCT], format: _Format, dtype: None = None) -> _SpMatrix[_SCT]: ...
16761781

16771782
#
1678-
@overload # blocks: <unknown, unknown>, format: <default>, dtype: <known>
1679-
def bmat(
1680-
blocks: _ToBlocks, format: _FmtCOO | None = None, *, dtype: onp.ToDType[_SCT]
1681-
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1682-
@overload # blocks: <unknown, unknown>, format: <otherwise>, dtype: <known>
1683-
def bmat(blocks: _ToBlocks, format: _FmtNonCOO, dtype: onp.ToDType[_SCT]) -> _SpBase2D[_SCT]: ...
1783+
@overload # blocks: <known>, format: <default>, dtype: <known>
1784+
def bmat(blocks: _ToBlocksCanStackAs[_SCT0, _T], format: None = None, *, dtype: onp.ToDType[_SCT0]) -> _T: ...
1785+
@overload # blocks: <matrix>, format: <otherwise>, dtype: <known>
1786+
def bmat(blocks: _ToBlocksSPMatrix[_SCT0], format: _Format, dtype: onp.ToDType[_SCT0]) -> _SpMatrix[_SCT0]: ...
1787+
@overload # blocks: <unknown>, format: <otherwise>, dtype: <known>
1788+
def bmat(blocks: _ToBlocksUnkown, format: _Format, dtype: onp.ToDType[_SCT0]) -> _SpMatrix[_SCT0] | _SpArray2D[_SCT0]: ...
16841789

16851790
#
1686-
@overload # blocks: <unknown, unknown>, format: <default>, dtype: <unknown>
1687-
def bmat(
1688-
blocks: _ToBlocks[_SCT], format: _FmtCOO | None = None, dtype: npt.DTypeLike | None = None
1689-
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1690-
@overload # blocks: <unknown, unknown>, format: <otherwise>, dtype: <unknown>
1691-
def bmat(blocks: _ToBlocks[_SCT], format: _FmtNonCOO, dtype: npt.DTypeLike) -> _SpBase2D[_SCT]: ...
1791+
@overload # blocks: <known>, format: <default>, dtype: <unknown>
1792+
def bmat(blocks: _ToBlocksCanStackAs[Any, _T], format: None = None, *, dtype: npt.DTypeLike) -> _T: ...
1793+
@overload # blocks: <matrix>, format: <otherwise>, dtype: <unknown>
1794+
def bmat(blocks: _ToBlocksSPMatrix, format: _Format, *, dtype: npt.DTypeLike) -> _SpMatrix: ...
1795+
@overload # blocks: <unknown>, format: <otherwise>, dtype: <unknown>
1796+
def bmat(blocks: _ToBlocksUnkown, format: _Format, *, dtype: npt.DTypeLike) -> _SpMatrix | _SpArray2D: ...
16921797

16931798
###
16941799
# TODO(jorenham): Add support for non-COO formats.

0 commit comments

Comments
 (0)