Skip to content

Commit cdd0117

Browse files
Improved stubs for block_diag
1 parent be2ca0d commit cdd0117

File tree

2 files changed

+208
-43
lines changed

2 files changed

+208
-43
lines changed

scipy-stubs/sparse/_construct.pyi

Lines changed: 195 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,27 +1757,27 @@ def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: onp.ToDType[_SC
17571757

17581758
#
17591759
@overload # blocks: <known, unknown>, format: <default>, dtype: <unknown>
1760-
def block_array(blocks: _ToBlocksCanStackAs[Any, _T], format: None = None, *, dtype: npt.DTypeLike) -> _T: ...
1760+
def block_array(blocks: _ToBlocksCanStackAs[Any, _T], format: None = None, *, dtype: npt.DTypeLike | None = None) -> _T: ...
17611761
@overload # blocks: <unknown, unknown>, format: "bsr", dtype: <unknown>
1762-
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: npt.DTypeLike) -> _BSRArray: ...
1762+
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: npt.DTypeLike | None = None) -> _BSRArray: ...
17631763
@overload # blocks: <unknown, unknown>, format: "coo", dtype: <unknown>
1764-
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: npt.DTypeLike) -> _COOArray2D: ...
1764+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: npt.DTypeLike | None = None) -> _COOArray2D: ...
17651765
@overload # blocks: <unknown, unknown>, format: "csc", dtype: <unknown>
1766-
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: npt.DTypeLike) -> _CSCArray: ...
1766+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: npt.DTypeLike | None = None) -> _CSCArray: ...
17671767
@overload # blocks: <unknown, unknown>, format: "csr", dtype: <unknown>
1768-
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: npt.DTypeLike) -> _CSRArray2D: ...
1768+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: npt.DTypeLike | None = None) -> _CSRArray2D: ...
17691769
@overload # blocks: <unknown, unknown>, format: "dia", dtype: <unknown>
1770-
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: npt.DTypeLike) -> _DIAArray: ...
1770+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: npt.DTypeLike | None = None) -> _DIAArray: ...
17711771
@overload # blocks: <unknown, unknown>, format: "dok", dtype: <unknown>
1772-
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: npt.DTypeLike) -> _DOKArray2D: ...
1772+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: npt.DTypeLike | None = None) -> _DOKArray2D: ...
17731773
@overload # blocks: <unknown, unknown>, format: "lil", dtype: <unknown>
1774-
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: npt.DTypeLike) -> _LILArray: ...
1774+
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: npt.DTypeLike | None = None) -> _LILArray: ...
17751775

17761776
###
17771777
@overload # blocks: <known, known>, format: <default>, dtype: <default>
17781778
def bmat(blocks: _ToBlocksCanStack[_T], format: None = None, dtype: None = None) -> _T: ...
17791779
@overload # blocks: <matrix, known>, format: <otherwise>, dtype: <default>
1780-
def bmat(blocks:_ToBlocksSPMatrix[_SCT], format: _Format, dtype: None = None) -> _SpMatrix[_SCT]: ...
1780+
def bmat(blocks: _ToBlocksSPMatrix[_SCT], format: _Format, dtype: None = None) -> _SpMatrix[_SCT]: ...
17811781

17821782
#
17831783
@overload # blocks: <known, known>, format: <default>, dtype: <known>
@@ -1796,69 +1796,223 @@ def bmat(blocks: _ToBlocksSPMatrix, format: _Format, *, dtype: npt.DTypeLike) ->
17961796
def bmat(blocks: _ToBlocksUnkown, format: _Format, *, dtype: npt.DTypeLike) -> _SpMatrix | _SpArray2D: ...
17971797

17981798
###
1799-
# TODO(jorenham): Add support for non-COO formats.
18001799
@overload # mats: <array, known>, format: <default>, dtype: None
18011800
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ...
1802-
1803-
#
1804-
@overload # mats: <matrix, known>, format: <default>, dtype: None
1805-
def block_diag(mats: Iterable[spmatrix[_SCT]], format: _FmtCOO | None = None, dtype: None = None) -> coo_matrix[_SCT]: ...
1801+
@overload # mats: <array, known>, format: "bsr", dtype: None
1802+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtBSR, dtype: None = None) -> _BSRArray[_SCT]: ...
1803+
@overload # mats: <array, known>, format: "csc", dtype: None
1804+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtCSC, dtype: None = None) -> _CSCArray[_SCT]: ...
1805+
@overload # mats: <array, known>, format: "csr", dtype: None
1806+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtCSR, dtype: None = None) -> _CSRArray2D[_SCT]: ...
1807+
@overload # mats: <array, known>, format: "dia", dtype: None
1808+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtDIA, dtype: None = None) -> _DIAArray[_SCT]: ...
1809+
@overload # mats: <array, known>, format: "dok", dtype: None
1810+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtDOK, dtype: None = None) -> _DOKArray2D[_SCT]: ...
1811+
@overload # mats: <array, known>, format: "lil", dtype: None
1812+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtLIL, dtype: None = None) -> _LILArray[_SCT]: ...
1813+
1814+
#
1815+
@overload # mats: <array, unknown>, format: <default>, dtype: bool-like
1816+
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.AnyBoolDType) -> _COOArray2D[np.bool_]: ...
1817+
@overload # mats: <array, unknown>, format: "bsr", dtype: bool-like
1818+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.AnyBoolDType) -> _BSRArray[np.bool_]: ...
1819+
@overload # mats: <array, unknown>, format: "csc", dtype: bool-like
1820+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.AnyBoolDType) -> _CSCArray[np.bool_]: ...
1821+
@overload # mats: <array, unknown>, format: "csr", dtype: bool-like
1822+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.AnyBoolDType) -> _CSRArray2D[np.bool_]: ...
1823+
@overload # mats: <array, unknown>, format: "dia", dtype: bool-like
1824+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.AnyBoolDType) -> _DIAArray[np.bool_]: ...
1825+
@overload # mats: <array, unknown>, format: "dok", dtype: bool-like
1826+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.AnyBoolDType) -> _DOKArray2D[np.bool_]: ...
1827+
@overload # mats: <array, unknown>, format: "lil", dtype: bool-like
1828+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.AnyBoolDType) -> _LILArray[np.bool_]: ...
1829+
1830+
#
1831+
@overload # mats: <array, unknown>, format: <default>, dtype: int-like
1832+
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.AnyIntDType) -> _COOArray2D[np.int_]: ...
1833+
@overload # mats: <array, unknown>, format: "bsr", dtype: int-like
1834+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.AnyIntDType) -> _BSRArray[np.int_]: ...
1835+
@overload # mats: <array, unknown>, format: "csc", dtype: int-like
1836+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.AnyIntDType) -> _CSCArray[np.int_]: ...
1837+
@overload # mats: <array, unknown>, format: "csr", dtype: int-like
1838+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.AnyIntDType) -> _CSRArray2D[np.int_]: ...
1839+
@overload # mats: <array, unknown>, format: "dia", dtype: int-like
1840+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.AnyIntDType) -> _DIAArray[np.int_]: ...
1841+
@overload # mats: <array, unknown>, format: "dok", dtype: int-like
1842+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.AnyIntDType) -> _DOKArray2D[np.int_]: ...
1843+
@overload # mats: <array, unknown>, format: "lil", dtype: int-like
1844+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.AnyIntDType) -> _LILArray[np.int_]: ...
1845+
1846+
#
1847+
@overload # mats: <array, unknown>, format: <default>, dtype: float64-like
1848+
def block_diag(
1849+
mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.AnyFloat64DType
1850+
) -> _COOArray2D[np.float64]: ...
1851+
@overload # mats: <array, unknown>, format: "bsr", dtype: float64-like
1852+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.AnyFloat64DType) -> _BSRArray[np.float64]: ...
1853+
@overload # mats: <array, unknown>, format: "csc", dtype: float64-like
1854+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.AnyFloat64DType) -> _CSCArray[np.float64]: ...
1855+
@overload # mats: <array, unknown>, format: "csr", dtype: float64-like
1856+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.AnyFloat64DType) -> _CSRArray2D[np.float64]: ...
1857+
@overload # mats: <array, unknown>, format: "dia", dtype: float64-like
1858+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.AnyFloat64DType) -> _DIAArray[np.float64]: ...
1859+
@overload # mats: <array, unknown>, format: "dok", dtype: float64-like
1860+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.AnyFloat64DType) -> _DOKArray2D[np.float64]: ...
1861+
@overload # mats: <array, unknown>, format: "lil", dtype: float64-like
1862+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.AnyFloat64DType) -> _LILArray[np.float64]: ...
1863+
1864+
#
1865+
@overload # mats: <array, unknown>, format: <default>, dtype: complex128-like
1866+
def block_diag(
1867+
mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.AnyComplex128DType
1868+
) -> _COOArray2D[np.complex128]: ...
1869+
@overload # mats: <array, unknown>, format: "bsr", dtype: complex128-like
1870+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.AnyComplex128DType) -> _BSRArray[np.complex128]: ...
1871+
@overload # mats: <array, unknown>, format: "csc", dtype: complex128-like
1872+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.AnyComplex128DType) -> _CSCArray[np.complex128]: ...
1873+
@overload # mats: <array, unknown>, format: "csr", dtype: complex128-like
1874+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.AnyComplex128DType) -> _CSRArray2D[np.complex128]: ...
1875+
@overload # mats: <array, unknown>, format: "dia", dtype: complex128-like
1876+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.AnyComplex128DType) -> _DIAArray[np.complex128]: ...
1877+
@overload # mats: <array, unknown>, format: "dok", dtype: complex128-like
1878+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.AnyComplex128DType) -> _DOKArray2D[np.complex128]: ...
1879+
@overload # mats: <array, unknown>, format: "lil", dtype: complex128-like
1880+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.AnyComplex128DType) -> _LILArray[np.complex128]: ...
1881+
1882+
#
1883+
@overload # mats: <array, unknown>, format: <default>, dtype: <known>
1884+
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.ToDType[_SCT]) -> _COOArray2D[_SCT]: ...
1885+
@overload # mats: <array, unknown>, format: "bsr", dtype: <known>
1886+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.ToDType[_SCT]) -> _BSRArray[_SCT]: ...
1887+
@overload # mats: <array, unknown>, format: "csc", dtype: <known>
1888+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.ToDType[_SCT]) -> _CSCArray[_SCT]: ...
1889+
@overload # mats: <array, unknown>, format: "csr", dtype: <known>
1890+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.ToDType[_SCT]) -> _CSRArray2D[_SCT]: ...
1891+
@overload # mats: <array, unknown>, format: "dia", dtype: <known>
1892+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.ToDType[_SCT]) -> _DIAArray[_SCT]: ...
1893+
@overload # mats: <array, unknown>, format: "dok", dtype: <known>
1894+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.ToDType[_SCT]) -> _DOKArray2D[_SCT]: ...
1895+
@overload # mats: <array, unknown>, format: "lil", dtype: <known>
1896+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.ToDType[_SCT]) -> _LILArray[_SCT]: ...
18061897

18071898
#
18081899
@overload # mats: <unknown, known>, format: <default>, dtype: None
18091900
def block_diag(
1810-
mats: Iterable[_spbase[_SCT] | onp.ArrayND[_SCT]], format: _FmtCOO | None = None, dtype: None = None
1901+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtCOO | None = None, dtype: None = None
18111902
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1903+
@overload # mats: <unknown, known>, format: "bsr", dtype: None
1904+
def block_diag(
1905+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtBSR, dtype: None = None
1906+
) -> _BSRArray[_SCT] | bsr_matrix[_SCT]: ...
1907+
@overload # mats: <unknown, known>, format: "csc", dtype: None
1908+
def block_diag(
1909+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtCSC, dtype: None = None
1910+
) -> _CSCArray[_SCT] | csc_matrix[_SCT]: ...
1911+
@overload # mats: <unknown, known>, format: "csr", dtype: None
1912+
def block_diag(
1913+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtCSR, dtype: None = None
1914+
) -> _CSRArray2D[_SCT] | csr_matrix[_SCT]: ...
1915+
@overload # mats: <unknown, known>, format: "dia", dtype: None
1916+
def block_diag(
1917+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtDIA, dtype: None = None
1918+
) -> _DIAArray[_SCT] | dia_matrix[_SCT]: ...
1919+
@overload # mats: <unknown, known>, format: "dok", dtype: None
1920+
def block_diag(
1921+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtDOK, dtype: None = None
1922+
) -> _DOKArray2D[_SCT] | dok_matrix[_SCT]: ...
1923+
@overload # mats: <unknown, known>, format: "lil", dtype: None
1924+
def block_diag(
1925+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtLIL, dtype: None = None
1926+
) -> _LILArray[_SCT] | lil_matrix[_SCT]: ...
18121927

18131928
#
1814-
@overload # mats: <array, unknown>, format: <default>, dtype: <known> (keyword)
1815-
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.ToDType[_SCT]) -> _COOArray2D[_SCT]: ...
1816-
@overload # mats: <array, unknown>, format: <default>, dtype: <known> (positional)
1817-
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None, dtype: onp.ToDType[_SCT]) -> _COOArray2D[_SCT]: ...
1818-
1819-
#
1820-
@overload # mats: <matrix, unknown>, format: <default>, dtype: <known> (keyword)
1929+
@overload # mats: <unknown, unknown>, format: <default>, dtype: <known>
18211930
def block_diag(
1822-
mats: Iterable[spmatrix | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1931+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
18231932
format: _FmtCOO | None = None,
18241933
*,
18251934
dtype: onp.ToDType[_SCT],
1826-
) -> coo_matrix[_SCT]: ...
1827-
@overload # mats: <matrix, unknown>, format: <default>, dtype: <known> (positional)
1935+
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1936+
@overload # mats: <unknown, unknown>, format: "bsr", dtype: <known>
18281937
def block_diag(
1829-
mats: Iterable[spmatrix | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1830-
format: _FmtCOO | None,
1938+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1939+
format: _FmtBSR,
18311940
dtype: onp.ToDType[_SCT],
1832-
) -> coo_matrix[_SCT]: ...
1833-
1834-
#
1835-
@overload # mats: <unknown, unknown>, format: <default>, dtype: <known> (keyword)
1941+
) -> _BSRArray[_SCT] | bsr_matrix[_SCT]: ...
1942+
@overload # mats: <unknown, unknown>, format: "csc", dtype: <known>
18361943
def block_diag(
18371944
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1838-
format: _FmtCOO | None = None,
1839-
*,
1945+
format: _FmtCSC,
18401946
dtype: onp.ToDType[_SCT],
1841-
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1842-
@overload # mats: <unknown, unknown>, format: <default>, dtype: <known> (positional)
1947+
) -> _CSCArray[_SCT] | csc_matrix[_SCT]: ...
1948+
@overload # mats: <unknown, unknown>, format: "csr", dtype: <known>
18431949
def block_diag(
18441950
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1845-
format: _FmtCOO | None,
1951+
format: _FmtCSR,
18461952
dtype: onp.ToDType[_SCT],
1847-
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1953+
) -> _CSRArray2D[_SCT] | csr_matrix[_SCT]: ...
1954+
@overload # mats: <unknown, unknown>, format: "dia", dtype: <known>
1955+
def block_diag(
1956+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1957+
format: _FmtDIA,
1958+
dtype: onp.ToDType[_SCT],
1959+
) -> _DIAArray[_SCT] | dia_matrix[_SCT]: ...
1960+
@overload # mats: <unknown, unknown>, format: "dok", dtype: <known>
1961+
def block_diag(
1962+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1963+
format: _FmtDOK,
1964+
dtype: onp.ToDType[_SCT],
1965+
) -> _DOKArray2D[_SCT] | dok_matrix[_SCT]: ...
1966+
@overload # mats: <unknown, unknown>, format: "lil", dtype: <known>
1967+
def block_diag(
1968+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1969+
format: _FmtLIL,
1970+
dtype: onp.ToDType[_SCT],
1971+
) -> _LILArray[_SCT] | lil_matrix[_SCT]: ...
18481972

18491973
#
18501974
@overload # mats: <unknown, unknown>, format: <default>, dtype: <unknown>
18511975
def block_diag(
18521976
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
18531977
format: _FmtCOO | None = None,
18541978
dtype: npt.DTypeLike | None = None,
1855-
) -> _COOArray2D[_SCT] | coo_matrix[Any]: ...
1856-
@overload # mats: <unknown, unknown>, format: <otherwise>, dtype: <unknown>
1979+
) -> _COOArray2D | coo_matrix: ...
1980+
@overload # mats: <unknown, unknown>, format: "bsr", dtype: <unknown>
1981+
def block_diag(
1982+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1983+
format: _FmtBSR,
1984+
dtype: npt.DTypeLike | None = None,
1985+
) -> _BSRArray | bsr_matrix: ...
1986+
@overload # mats: <unknown, unknown>, format: "csc", dtype: <unknown>
1987+
def block_diag(
1988+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1989+
format: _FmtCSC,
1990+
dtype: npt.DTypeLike | None = None,
1991+
) -> _CSCArray | csc_matrix: ...
1992+
@overload # mats: <unknown, unknown>, format: "csr", dtype: <unknown>
1993+
def block_diag(
1994+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1995+
format: _FmtCSR,
1996+
dtype: npt.DTypeLike | None = None,
1997+
) -> _CSRArray2D | csr_matrix: ...
1998+
@overload # mats: <unknown, unknown>, format: "dia", dtype: <unknown>
18571999
def block_diag(
18582000
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1859-
format: _FmtNonCOO | None = None,
2001+
format: _FmtDIA,
2002+
dtype: npt.DTypeLike | None = None,
2003+
) -> _DIAArray | dia_matrix: ...
2004+
@overload # mats: <unknown, unknown>, format: "dok", dtype: <unknown>
2005+
def block_diag(
2006+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
2007+
format: _FmtDOK,
2008+
dtype: npt.DTypeLike | None = None,
2009+
) -> _DOKArray2D | dok_matrix: ...
2010+
@overload # mats: <unknown, unknown>, format: "lil", dtype: <unknown>
2011+
def block_diag(
2012+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
2013+
format: _FmtLIL,
18602014
dtype: npt.DTypeLike | None = None,
1861-
) -> Incomplete: ...
2015+
) -> _LILArray | lil_matrix: ...
18622016

18632017
###
18642018
@overload # shape: T, format: <default>, dtype: <default>

tests/sparse/test_construct.pyi

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,20 @@ assert_type(sparse.bmat([[dia_arr]], dtype=np.int_), sparse.coo_array[np.int_, t
283283
assert_type(sparse.bmat([[dok_arr], [None]], dtype=np.complex64), sparse.coo_array[np.complex64, tuple[int, int]])
284284

285285
# block_diag
286-
assert_type(sparse.block_diag([any_mat, any_mat]), sparse.coo_matrix[ScalarType])
287286
assert_type(sparse.block_diag([any_arr, any_arr]), sparse.coo_array[ScalarType, tuple[int, int]])
288-
assert_type(sparse.block_diag([any_arr, any_mat]), sparse.coo_array[ScalarType, tuple[int, int]] | sparse.coo_matrix[ScalarType])
287+
assert_type(sparse.block_diag([any_arr, any_mat]), sparse.coo_matrix[ScalarType] | sparse.coo_array[ScalarType, tuple[int, int]])
288+
assert_type(sparse.block_diag([any_arr, any_arr], dtype="bool"), sparse.coo_array[np.bool_, tuple[int, int]])
289+
assert_type(sparse.block_diag([any_arr, any_arr], dtype=int), sparse.coo_array[np.int_, tuple[int, int]])
290+
assert_type(
291+
sparse.block_diag([any_arr, any_mat], dtype=np.complex64),
292+
sparse.coo_matrix[np.complex64] | sparse.coo_array[np.complex64, tuple[int, int]],
293+
)
294+
assert_type(sparse.block_diag([any_arr, any_arr], format="bsr"), sparse.bsr_array[ScalarType])
295+
assert_type(sparse.block_diag([any_arr, any_arr], format="csc", dtype=float), sparse.csc_array[np.float64])
296+
assert_type(sparse.block_diag([any_arr, any_arr], format="csr", dtype=complex), sparse.csr_array[np.complex128, tuple[int, int]])
297+
assert_type(sparse.block_diag([any_arr, any_arr], format="dia", dtype=np.int32), sparse.dia_array[np.int32])
298+
assert_type(sparse.block_diag([any_arr, any_arr], format="dok", dtype="bool"), sparse.dok_array[np.bool_, tuple[int, int]])
299+
assert_type(sparse.block_diag([any_arr, any_arr], format="lil", dtype=np.complex64), sparse.lil_array[np.complex64])
289300

290301
###
291302
# random_array

0 commit comments

Comments
 (0)