Skip to content

Commit be391cf

Browse files
JulVandenBroeckjulvdb
authored andcommitted
Improved stubs for block_diag
1 parent e1f85fc commit be391cf

File tree

2 files changed

+208
-43
lines changed

2 files changed

+208
-43
lines changed

scipy-stubs/sparse/_construct.pyi

Lines changed: 195 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,27 +1750,27 @@ def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: onp.ToDType[_SC
17501750

17511751
#
17521752
@overload # blocks: <known, unknown>, format: <default>, dtype: <unknown>
1753-
def block_array(blocks: _ToBlocksCanStackAs[Any, _T], format: None = None, *, dtype: npt.DTypeLike) -> _T: ...
1753+
def block_array(blocks: _ToBlocksCanStackAs[Any, _T], format: None = None, *, dtype: npt.DTypeLike | None = None) -> _T: ...
17541754
@overload # blocks: <unknown, unknown>, format: "bsr", dtype: <unknown>
1755-
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: npt.DTypeLike) -> _BSRArray: ...
1755+
def block_array(blocks: _ToBlocksUnkown, format: _FmtBSR, dtype: npt.DTypeLike | None = None) -> _BSRArray: ...
17561756
@overload # blocks: <unknown, unknown>, format: "coo", dtype: <unknown>
1757-
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: npt.DTypeLike) -> _COOArray2D: ...
1757+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCOO, dtype: npt.DTypeLike | None = None) -> _COOArray2D: ...
17581758
@overload # blocks: <unknown, unknown>, format: "csc", dtype: <unknown>
1759-
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: npt.DTypeLike) -> _CSCArray: ...
1759+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSC, dtype: npt.DTypeLike | None = None) -> _CSCArray: ...
17601760
@overload # blocks: <unknown, unknown>, format: "csr", dtype: <unknown>
1761-
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: npt.DTypeLike) -> _CSRArray2D: ...
1761+
def block_array(blocks: _ToBlocksUnkown, format: _FmtCSR, dtype: npt.DTypeLike | None = None) -> _CSRArray2D: ...
17621762
@overload # blocks: <unknown, unknown>, format: "dia", dtype: <unknown>
1763-
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: npt.DTypeLike) -> _DIAArray: ...
1763+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDIA, dtype: npt.DTypeLike | None = None) -> _DIAArray: ...
17641764
@overload # blocks: <unknown, unknown>, format: "dok", dtype: <unknown>
1765-
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: npt.DTypeLike) -> _DOKArray2D: ...
1765+
def block_array(blocks: _ToBlocksUnkown, format: _FmtDOK, dtype: npt.DTypeLike | None = None) -> _DOKArray2D: ...
17661766
@overload # blocks: <unknown, unknown>, format: "lil", dtype: <unknown>
1767-
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: npt.DTypeLike) -> _LILArray: ...
1767+
def block_array(blocks: _ToBlocksUnkown, format: _FmtLIL, dtype: npt.DTypeLike | None = None) -> _LILArray: ...
17681768

17691769
###
17701770
@overload # blocks: <known, known>, format: <default>, dtype: <default>
17711771
def bmat(blocks: _ToBlocksCanStack[_T], format: None = None, dtype: None = None) -> _T: ...
17721772
@overload # blocks: <matrix, known>, format: <otherwise>, dtype: <default>
1773-
def bmat(blocks:_ToBlocksSPMatrix[_SCT], format: _Format, dtype: None = None) -> _SpMatrix[_SCT]: ...
1773+
def bmat(blocks: _ToBlocksSPMatrix[_SCT], format: _Format, dtype: None = None) -> _SpMatrix[_SCT]: ...
17741774

17751775
#
17761776
@overload # blocks: <known, known>, format: <default>, dtype: <known>
@@ -1789,69 +1789,223 @@ def bmat(blocks: _ToBlocksSPMatrix, format: _Format, *, dtype: npt.DTypeLike) ->
17891789
def bmat(blocks: _ToBlocksUnkown, format: _Format, *, dtype: npt.DTypeLike) -> _SpMatrix | _SpArray2D: ...
17901790

17911791
###
1792-
# TODO(jorenham): Add support for non-COO formats.
17931792
@overload # mats: <array, known>, format: <default>, dtype: None
17941793
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ...
1795-
1796-
#
1797-
@overload # mats: <matrix, known>, format: <default>, dtype: None
1798-
def block_diag(mats: Iterable[spmatrix[_SCT]], format: _FmtCOO | None = None, dtype: None = None) -> coo_matrix[_SCT]: ...
1794+
@overload # mats: <array, known>, format: "bsr", dtype: None
1795+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtBSR, dtype: None = None) -> _BSRArray[_SCT]: ...
1796+
@overload # mats: <array, known>, format: "csc", dtype: None
1797+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtCSC, dtype: None = None) -> _CSCArray[_SCT]: ...
1798+
@overload # mats: <array, known>, format: "csr", dtype: None
1799+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtCSR, dtype: None = None) -> _CSRArray2D[_SCT]: ...
1800+
@overload # mats: <array, known>, format: "dia", dtype: None
1801+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtDIA, dtype: None = None) -> _DIAArray[_SCT]: ...
1802+
@overload # mats: <array, known>, format: "dok", dtype: None
1803+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtDOK, dtype: None = None) -> _DOKArray2D[_SCT]: ...
1804+
@overload # mats: <array, known>, format: "lil", dtype: None
1805+
def block_diag(mats: Iterable[sparray[_SCT]], format: _FmtLIL, dtype: None = None) -> _LILArray[_SCT]: ...
1806+
1807+
#
1808+
@overload # mats: <array, unknown>, format: <default>, dtype: bool-like
1809+
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.AnyBoolDType) -> _COOArray2D[np.bool_]: ...
1810+
@overload # mats: <array, unknown>, format: "bsr", dtype: bool-like
1811+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.AnyBoolDType) -> _BSRArray[np.bool_]: ...
1812+
@overload # mats: <array, unknown>, format: "csc", dtype: bool-like
1813+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.AnyBoolDType) -> _CSCArray[np.bool_]: ...
1814+
@overload # mats: <array, unknown>, format: "csr", dtype: bool-like
1815+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.AnyBoolDType) -> _CSRArray2D[np.bool_]: ...
1816+
@overload # mats: <array, unknown>, format: "dia", dtype: bool-like
1817+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.AnyBoolDType) -> _DIAArray[np.bool_]: ...
1818+
@overload # mats: <array, unknown>, format: "dok", dtype: bool-like
1819+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.AnyBoolDType) -> _DOKArray2D[np.bool_]: ...
1820+
@overload # mats: <array, unknown>, format: "lil", dtype: bool-like
1821+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.AnyBoolDType) -> _LILArray[np.bool_]: ...
1822+
1823+
#
1824+
@overload # mats: <array, unknown>, format: <default>, dtype: int-like
1825+
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.AnyIntDType) -> _COOArray2D[np.int_]: ...
1826+
@overload # mats: <array, unknown>, format: "bsr", dtype: int-like
1827+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.AnyIntDType) -> _BSRArray[np.int_]: ...
1828+
@overload # mats: <array, unknown>, format: "csc", dtype: int-like
1829+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.AnyIntDType) -> _CSCArray[np.int_]: ...
1830+
@overload # mats: <array, unknown>, format: "csr", dtype: int-like
1831+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.AnyIntDType) -> _CSRArray2D[np.int_]: ...
1832+
@overload # mats: <array, unknown>, format: "dia", dtype: int-like
1833+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.AnyIntDType) -> _DIAArray[np.int_]: ...
1834+
@overload # mats: <array, unknown>, format: "dok", dtype: int-like
1835+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.AnyIntDType) -> _DOKArray2D[np.int_]: ...
1836+
@overload # mats: <array, unknown>, format: "lil", dtype: int-like
1837+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.AnyIntDType) -> _LILArray[np.int_]: ...
1838+
1839+
#
1840+
@overload # mats: <array, unknown>, format: <default>, dtype: float64-like
1841+
def block_diag(
1842+
mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.AnyFloat64DType
1843+
) -> _COOArray2D[np.float64]: ...
1844+
@overload # mats: <array, unknown>, format: "bsr", dtype: float64-like
1845+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.AnyFloat64DType) -> _BSRArray[np.float64]: ...
1846+
@overload # mats: <array, unknown>, format: "csc", dtype: float64-like
1847+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.AnyFloat64DType) -> _CSCArray[np.float64]: ...
1848+
@overload # mats: <array, unknown>, format: "csr", dtype: float64-like
1849+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.AnyFloat64DType) -> _CSRArray2D[np.float64]: ...
1850+
@overload # mats: <array, unknown>, format: "dia", dtype: float64-like
1851+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.AnyFloat64DType) -> _DIAArray[np.float64]: ...
1852+
@overload # mats: <array, unknown>, format: "dok", dtype: float64-like
1853+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.AnyFloat64DType) -> _DOKArray2D[np.float64]: ...
1854+
@overload # mats: <array, unknown>, format: "lil", dtype: float64-like
1855+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.AnyFloat64DType) -> _LILArray[np.float64]: ...
1856+
1857+
#
1858+
@overload # mats: <array, unknown>, format: <default>, dtype: complex128-like
1859+
def block_diag(
1860+
mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.AnyComplex128DType
1861+
) -> _COOArray2D[np.complex128]: ...
1862+
@overload # mats: <array, unknown>, format: "bsr", dtype: complex128-like
1863+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.AnyComplex128DType) -> _BSRArray[np.complex128]: ...
1864+
@overload # mats: <array, unknown>, format: "csc", dtype: complex128-like
1865+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.AnyComplex128DType) -> _CSCArray[np.complex128]: ...
1866+
@overload # mats: <array, unknown>, format: "csr", dtype: complex128-like
1867+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.AnyComplex128DType) -> _CSRArray2D[np.complex128]: ...
1868+
@overload # mats: <array, unknown>, format: "dia", dtype: complex128-like
1869+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.AnyComplex128DType) -> _DIAArray[np.complex128]: ...
1870+
@overload # mats: <array, unknown>, format: "dok", dtype: complex128-like
1871+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.AnyComplex128DType) -> _DOKArray2D[np.complex128]: ...
1872+
@overload # mats: <array, unknown>, format: "lil", dtype: complex128-like
1873+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.AnyComplex128DType) -> _LILArray[np.complex128]: ...
1874+
1875+
#
1876+
@overload # mats: <array, unknown>, format: <default>, dtype: <known>
1877+
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.ToDType[_SCT]) -> _COOArray2D[_SCT]: ...
1878+
@overload # mats: <array, unknown>, format: "bsr", dtype: <known>
1879+
def block_diag(mats: Iterable[sparray], format: _FmtBSR, dtype: onp.ToDType[_SCT]) -> _BSRArray[_SCT]: ...
1880+
@overload # mats: <array, unknown>, format: "csc", dtype: <known>
1881+
def block_diag(mats: Iterable[sparray], format: _FmtCSC, dtype: onp.ToDType[_SCT]) -> _CSCArray[_SCT]: ...
1882+
@overload # mats: <array, unknown>, format: "csr", dtype: <known>
1883+
def block_diag(mats: Iterable[sparray], format: _FmtCSR, dtype: onp.ToDType[_SCT]) -> _CSRArray2D[_SCT]: ...
1884+
@overload # mats: <array, unknown>, format: "dia", dtype: <known>
1885+
def block_diag(mats: Iterable[sparray], format: _FmtDIA, dtype: onp.ToDType[_SCT]) -> _DIAArray[_SCT]: ...
1886+
@overload # mats: <array, unknown>, format: "dok", dtype: <known>
1887+
def block_diag(mats: Iterable[sparray], format: _FmtDOK, dtype: onp.ToDType[_SCT]) -> _DOKArray2D[_SCT]: ...
1888+
@overload # mats: <array, unknown>, format: "lil", dtype: <known>
1889+
def block_diag(mats: Iterable[sparray], format: _FmtLIL, dtype: onp.ToDType[_SCT]) -> _LILArray[_SCT]: ...
17991890

18001891
#
18011892
@overload # mats: <unknown, known>, format: <default>, dtype: None
18021893
def block_diag(
1803-
mats: Iterable[_spbase[_SCT] | onp.ArrayND[_SCT]], format: _FmtCOO | None = None, dtype: None = None
1894+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtCOO | None = None, dtype: None = None
18041895
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1896+
@overload # mats: <unknown, known>, format: "bsr", dtype: None
1897+
def block_diag(
1898+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtBSR, dtype: None = None
1899+
) -> _BSRArray[_SCT] | bsr_matrix[_SCT]: ...
1900+
@overload # mats: <unknown, known>, format: "csc", dtype: None
1901+
def block_diag(
1902+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtCSC, dtype: None = None
1903+
) -> _CSCArray[_SCT] | csc_matrix[_SCT]: ...
1904+
@overload # mats: <unknown, known>, format: "csr", dtype: None
1905+
def block_diag(
1906+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtCSR, dtype: None = None
1907+
) -> _CSRArray2D[_SCT] | csr_matrix[_SCT]: ...
1908+
@overload # mats: <unknown, known>, format: "dia", dtype: None
1909+
def block_diag(
1910+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtDIA, dtype: None = None
1911+
) -> _DIAArray[_SCT] | dia_matrix[_SCT]: ...
1912+
@overload # mats: <unknown, known>, format: "dok", dtype: None
1913+
def block_diag(
1914+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtDOK, dtype: None = None
1915+
) -> _DOKArray2D[_SCT] | dok_matrix[_SCT]: ...
1916+
@overload # mats: <unknown, known>, format: "lil", dtype: None
1917+
def block_diag(
1918+
mats: Iterable[_spbase[_SCT] | onp.ToArrayND[_SCT]], format: _FmtLIL, dtype: None = None
1919+
) -> _LILArray[_SCT] | lil_matrix[_SCT]: ...
18051920

18061921
#
1807-
@overload # mats: <array, unknown>, format: <default>, dtype: <known> (keyword)
1808-
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None = None, *, dtype: onp.ToDType[_SCT]) -> _COOArray2D[_SCT]: ...
1809-
@overload # mats: <array, unknown>, format: <default>, dtype: <known> (positional)
1810-
def block_diag(mats: Iterable[sparray], format: _FmtCOO | None, dtype: onp.ToDType[_SCT]) -> _COOArray2D[_SCT]: ...
1811-
1812-
#
1813-
@overload # mats: <matrix, unknown>, format: <default>, dtype: <known> (keyword)
1922+
@overload # mats: <unknown, unknown>, format: <default>, dtype: <known>
18141923
def block_diag(
1815-
mats: Iterable[spmatrix | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1924+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
18161925
format: _FmtCOO | None = None,
18171926
*,
18181927
dtype: onp.ToDType[_SCT],
1819-
) -> coo_matrix[_SCT]: ...
1820-
@overload # mats: <matrix, unknown>, format: <default>, dtype: <known> (positional)
1928+
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1929+
@overload # mats: <unknown, unknown>, format: "bsr", dtype: <known>
18211930
def block_diag(
1822-
mats: Iterable[spmatrix | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1823-
format: _FmtCOO | None,
1931+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1932+
format: _FmtBSR,
18241933
dtype: onp.ToDType[_SCT],
1825-
) -> coo_matrix[_SCT]: ...
1826-
1827-
#
1828-
@overload # mats: <unknown, unknown>, format: <default>, dtype: <known> (keyword)
1934+
) -> _BSRArray[_SCT] | bsr_matrix[_SCT]: ...
1935+
@overload # mats: <unknown, unknown>, format: "csc", dtype: <known>
18291936
def block_diag(
18301937
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1831-
format: _FmtCOO | None = None,
1832-
*,
1938+
format: _FmtCSC,
18331939
dtype: onp.ToDType[_SCT],
1834-
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1835-
@overload # mats: <unknown, unknown>, format: <default>, dtype: <known> (positional)
1940+
) -> _CSCArray[_SCT] | csc_matrix[_SCT]: ...
1941+
@overload # mats: <unknown, unknown>, format: "csr", dtype: <known>
18361942
def block_diag(
18371943
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1838-
format: _FmtCOO | None,
1944+
format: _FmtCSR,
18391945
dtype: onp.ToDType[_SCT],
1840-
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
1946+
) -> _CSRArray2D[_SCT] | csr_matrix[_SCT]: ...
1947+
@overload # mats: <unknown, unknown>, format: "dia", dtype: <known>
1948+
def block_diag(
1949+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1950+
format: _FmtDIA,
1951+
dtype: onp.ToDType[_SCT],
1952+
) -> _DIAArray[_SCT] | dia_matrix[_SCT]: ...
1953+
@overload # mats: <unknown, unknown>, format: "dok", dtype: <known>
1954+
def block_diag(
1955+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1956+
format: _FmtDOK,
1957+
dtype: onp.ToDType[_SCT],
1958+
) -> _DOKArray2D[_SCT] | dok_matrix[_SCT]: ...
1959+
@overload # mats: <unknown, unknown>, format: "lil", dtype: <known>
1960+
def block_diag(
1961+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1962+
format: _FmtLIL,
1963+
dtype: onp.ToDType[_SCT],
1964+
) -> _LILArray[_SCT] | lil_matrix[_SCT]: ...
18411965

18421966
#
18431967
@overload # mats: <unknown, unknown>, format: <default>, dtype: <unknown>
18441968
def block_diag(
18451969
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
18461970
format: _FmtCOO | None = None,
18471971
dtype: npt.DTypeLike | None = None,
1848-
) -> _COOArray2D[_SCT] | coo_matrix[Any]: ...
1849-
@overload # mats: <unknown, unknown>, format: <otherwise>, dtype: <unknown>
1972+
) -> _COOArray2D | coo_matrix: ...
1973+
@overload # mats: <unknown, unknown>, format: "bsr", dtype: <unknown>
1974+
def block_diag(
1975+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1976+
format: _FmtBSR,
1977+
dtype: npt.DTypeLike | None = None,
1978+
) -> _BSRArray | bsr_matrix: ...
1979+
@overload # mats: <unknown, unknown>, format: "csc", dtype: <unknown>
1980+
def block_diag(
1981+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1982+
format: _FmtCSC,
1983+
dtype: npt.DTypeLike | None = None,
1984+
) -> _CSCArray | csc_matrix: ...
1985+
@overload # mats: <unknown, unknown>, format: "csr", dtype: <unknown>
1986+
def block_diag(
1987+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1988+
format: _FmtCSR,
1989+
dtype: npt.DTypeLike | None = None,
1990+
) -> _CSRArray2D | csr_matrix: ...
1991+
@overload # mats: <unknown, unknown>, format: "dia", dtype: <unknown>
18501992
def block_diag(
18511993
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
1852-
format: _FmtNonCOO | None = None,
1994+
format: _FmtDIA,
1995+
dtype: npt.DTypeLike | None = None,
1996+
) -> _DIAArray | dia_matrix: ...
1997+
@overload # mats: <unknown, unknown>, format: "dok", dtype: <unknown>
1998+
def block_diag(
1999+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
2000+
format: _FmtDOK,
2001+
dtype: npt.DTypeLike | None = None,
2002+
) -> _DOKArray2D | dok_matrix: ...
2003+
@overload # mats: <unknown, unknown>, format: "lil", dtype: <unknown>
2004+
def block_diag(
2005+
mats: Iterable[_spbase | onp.ArrayND[_Numeric] | complex | Seq[onp.ToComplex] | Seq[onp.ToComplex1D]],
2006+
format: _FmtLIL,
18532007
dtype: npt.DTypeLike | None = None,
1854-
) -> Incomplete: ...
2008+
) -> _LILArray | lil_matrix: ...
18552009

18562010
###
18572011
@overload # shape: T, format: <default>, dtype: <default>

tests/sparse/test_construct.pyi

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,20 @@ assert_type(sparse.bmat([[dia_arr]], dtype=np.int_), sparse.coo_array[np.int_, t
283283
assert_type(sparse.bmat([[dok_arr], [None]], dtype=np.complex64), sparse.coo_array[np.complex64, tuple[int, int]])
284284

285285
# block_diag
286-
assert_type(sparse.block_diag([any_mat, any_mat]), sparse.coo_matrix[ScalarType])
287286
assert_type(sparse.block_diag([any_arr, any_arr]), sparse.coo_array[ScalarType, tuple[int, int]])
288-
assert_type(sparse.block_diag([any_arr, any_mat]), sparse.coo_array[ScalarType, tuple[int, int]] | sparse.coo_matrix[ScalarType])
287+
assert_type(sparse.block_diag([any_arr, any_mat]), sparse.coo_matrix[ScalarType] | sparse.coo_array[ScalarType, tuple[int, int]])
288+
assert_type(sparse.block_diag([any_arr, any_arr], dtype="bool"), sparse.coo_array[np.bool_, tuple[int, int]])
289+
assert_type(sparse.block_diag([any_arr, any_arr], dtype=int), sparse.coo_array[np.int_, tuple[int, int]])
290+
assert_type(
291+
sparse.block_diag([any_arr, any_mat], dtype=np.complex64),
292+
sparse.coo_matrix[np.complex64] | sparse.coo_array[np.complex64, tuple[int, int]],
293+
)
294+
assert_type(sparse.block_diag([any_arr, any_arr], format="bsr"), sparse.bsr_array[ScalarType])
295+
assert_type(sparse.block_diag([any_arr, any_arr], format="csc", dtype=float), sparse.csc_array[np.float64])
296+
assert_type(sparse.block_diag([any_arr, any_arr], format="csr", dtype=complex), sparse.csr_array[np.complex128, tuple[int, int]])
297+
assert_type(sparse.block_diag([any_arr, any_arr], format="dia", dtype=np.int32), sparse.dia_array[np.int32])
298+
assert_type(sparse.block_diag([any_arr, any_arr], format="dok", dtype="bool"), sparse.dok_array[np.bool_, tuple[int, int]])
299+
assert_type(sparse.block_diag([any_arr, any_arr], format="lil", dtype=np.complex64), sparse.lil_array[np.complex64])
289300

290301
###
291302
# random_array

0 commit comments

Comments
 (0)