Skip to content

Commit c2f8b79

Browse files
authored
🏷️ _lib: stub _ccallback_c and _array_api[_compat_vendor] (#713)
2 parents 907fff9 + 39ce795 commit c2f8b79

File tree

5 files changed

+224
-0
lines changed

5 files changed

+224
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ lint = [
5353
type = [
5454
{ include-group = "extras" },
5555
{ include-group = "ci" },
56+
"array-api-compat==1.12.0", # bundled as `scipy._lib.array_api_compat`
5657
"basedpyright>=1.29.5",
5758
"mypy>=1.16.1",
5859
"orjson>=3.10.18; python_version<'3.14'", # used by mypy

scipy-stubs/_lib/_array_api.pyi

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import dataclasses
2+
import re
3+
from _typeshed import Incomplete
4+
from collections.abc import Generator, Sequence
5+
from contextlib import contextmanager
6+
from types import ModuleType
7+
from typing import Any, Final, Literal, TypeAlias
8+
9+
from array_api_compat import (
10+
device as xp_device, # pyright: ignore[reportUnknownVariableType]
11+
is_array_api_strict_namespace as is_array_api_strict,
12+
is_cupy_namespace as is_cupy,
13+
is_jax_namespace as is_jax,
14+
is_lazy_array as is_lazy_array,
15+
is_numpy_namespace as is_numpy,
16+
is_torch_namespace as is_torch,
17+
size as xp_size,
18+
)
19+
20+
__all__ = [
21+
"SCIPY_ARRAY_API",
22+
"SCIPY_DEVICE",
23+
"_asarray",
24+
"array_namespace",
25+
"assert_almost_equal",
26+
"assert_array_almost_equal",
27+
"default_xp",
28+
"eager_warns",
29+
"is_array_api_strict",
30+
"is_complex",
31+
"is_cupy",
32+
"is_jax",
33+
"is_lazy_array",
34+
"is_marray",
35+
"is_numpy",
36+
"is_torch",
37+
"scipy_namespace_for",
38+
"xp_assert_close",
39+
"xp_assert_equal",
40+
"xp_assert_less",
41+
"xp_capabilities",
42+
"xp_copy",
43+
"xp_device",
44+
"xp_promote",
45+
"xp_ravel",
46+
"xp_result_type",
47+
"xp_size",
48+
"xp_unsupported_param_msg",
49+
"xp_vector_norm",
50+
]
51+
52+
SCIPY_ARRAY_API: Final[str | Literal[False]] = ...
53+
SCIPY_DEVICE: Final[str] = ...
54+
55+
Array: TypeAlias = Incomplete
56+
57+
def array_namespace(*arrays: Array) -> ModuleType: ...
58+
def _asarray(
59+
array: Any,
60+
dtype: Any = None,
61+
order: Literal["K", "A", "C", "F"] | None = None,
62+
copy: bool | None = None,
63+
*,
64+
xp: ModuleType | None = None,
65+
check_finite: bool = False,
66+
subok: bool = False,
67+
) -> Array: ...
68+
def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array: ...
69+
@contextmanager
70+
def default_xp(xp: ModuleType) -> Generator[None]: ...
71+
def eager_warns(
72+
x: Array, warning_type: type[Warning] | tuple[type[Warning], ...], match: str | re.Pattern[str] | None = None
73+
) -> Incomplete: ... # _pytest.recwarn.WarningsChecker
74+
def xp_assert_equal(
75+
actual: Incomplete,
76+
desired: Incomplete,
77+
*,
78+
check_namespace: bool = True,
79+
check_dtype: bool = True,
80+
check_shape: bool = True,
81+
check_0d: bool = True,
82+
err_msg: str = "",
83+
xp: ModuleType | None = None,
84+
) -> None: ...
85+
def xp_assert_close(
86+
actual: Incomplete,
87+
desired: Incomplete,
88+
*,
89+
rtol: Incomplete | None = None,
90+
atol: int = 0,
91+
check_namespace: bool = True,
92+
check_dtype: bool = True,
93+
check_shape: bool = True,
94+
check_0d: bool = True,
95+
err_msg: str = "",
96+
xp: ModuleType | None = None,
97+
) -> None: ...
98+
def xp_assert_less(
99+
actual: Incomplete,
100+
desired: Incomplete,
101+
*,
102+
check_namespace: bool = True,
103+
check_dtype: bool = True,
104+
check_shape: bool = True,
105+
check_0d: bool = True,
106+
err_msg: str = "",
107+
verbose: bool = True,
108+
xp: ModuleType | None = None,
109+
) -> None: ...
110+
def assert_array_almost_equal(
111+
actual: Incomplete, desired: Incomplete, decimal: int = 6, *args: Incomplete, **kwds: Incomplete
112+
) -> None: ...
113+
def assert_almost_equal(
114+
actual: Incomplete, desired: Incomplete, decimal: int = 7, *args: Incomplete, **kwds: Incomplete
115+
) -> None: ...
116+
def xp_unsupported_param_msg(param: Incomplete) -> str: ...
117+
def is_complex(x: Array, xp: ModuleType) -> bool: ...
118+
def scipy_namespace_for(xp: ModuleType) -> ModuleType | None: ...
119+
def xp_vector_norm(
120+
x: Array,
121+
/,
122+
*,
123+
axis: int | tuple[int, ...] | None = None,
124+
keepdims: bool = False,
125+
ord: float = 2,
126+
xp: ModuleType | None = None,
127+
) -> Array: ...
128+
def xp_ravel(x: Array, /, *, xp: ModuleType | None = None) -> Array: ...
129+
def xp_result_type(*args: Incomplete, force_floating: bool = False, xp: ModuleType) -> type: ...
130+
def xp_promote(*args: Incomplete, broadcast: bool = False, force_floating: bool = False, xp: ModuleType | None) -> Array: ...
131+
def is_marray(xp: ModuleType) -> bool: ...
132+
133+
@dataclasses.dataclass(repr=False)
134+
class _XPSphinxCapability:
135+
cpu: bool | None
136+
gpu: bool | None
137+
warnings: list[str] = ...
138+
139+
def _render(self, /, value: object) -> str: ...
140+
141+
def xp_capabilities(
142+
*,
143+
capabilities_table: Incomplete | None = None,
144+
skip_backends: Sequence[tuple[str, str]] = (),
145+
xfail_backends: Sequence[tuple[str, str]] = (),
146+
cpu_only: bool = False,
147+
np_only: bool = False,
148+
reason: str | None = None,
149+
exceptions: Sequence[str] = (),
150+
warnings: Sequence[tuple[str, str]] = (),
151+
allow_dask_compute: bool = False,
152+
jax_jit: bool = True,
153+
) -> dict[str, _XPSphinxCapability]: ...
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from array_api_compat import (
2+
common as common,
3+
device as device, # pyright: ignore[reportUnknownVariableType]
4+
get_namespace as get_namespace,
5+
is_array_api_obj as is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
6+
is_array_api_strict_namespace as is_array_api_strict_namespace,
7+
is_cupy_array as is_cupy_array,
8+
is_cupy_namespace as is_cupy_namespace,
9+
is_dask_array as is_dask_array, # pyright: ignore[reportUnknownVariableType]
10+
is_dask_namespace as is_dask_namespace,
11+
is_jax_array as is_jax_array, # pyright: ignore[reportUnknownVariableType]
12+
is_jax_namespace as is_jax_namespace,
13+
is_lazy_array as is_lazy_array,
14+
is_ndonnx_array as is_ndonnx_array, # pyright: ignore[reportUnknownVariableType]
15+
is_ndonnx_namespace as is_ndonnx_namespace,
16+
is_numpy_array as is_numpy_array,
17+
is_numpy_namespace as is_numpy_namespace,
18+
is_pydata_sparse_array as is_pydata_sparse_array, # pyright: ignore[reportUnknownVariableType]
19+
is_pydata_sparse_namespace as is_pydata_sparse_namespace,
20+
is_torch_array as is_torch_array, # pyright: ignore[reportUnknownVariableType]
21+
is_torch_namespace as is_torch_namespace,
22+
is_writeable_array as is_writeable_array,
23+
size as size,
24+
to_device as to_device,
25+
)
26+
27+
from ._array_api import array_namespace as scipy_array_namespace
28+
29+
array_namespace = scipy_array_namespace

scipy-stubs/_lib/_ccallback_c.pyi

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from collections.abc import Callable
2+
from typing import Any, Final, SupportsFloat, TypedDict, type_check_only
3+
from typing_extensions import CapsuleType, ReadOnly, TypeIs
4+
5+
@type_check_only
6+
class _CApiDict(TypedDict):
7+
plus1_cython: ReadOnly[CapsuleType]
8+
plus1b_cython: ReadOnly[CapsuleType]
9+
plus1bc_cython: ReadOnly[CapsuleType]
10+
sine: ReadOnly[CapsuleType]
11+
12+
###
13+
14+
__pyx_capi__: Final[_CApiDict] = ... # undocumented
15+
__test__: Final[dict[Any, Any]] = ... # undocumented
16+
17+
def check_capsule(item: object) -> TypeIs[CapsuleType]: ... # undocumented
18+
def get_capsule_signature(capsule_obj: CapsuleType) -> str: ... # undocumented
19+
def get_raw_capsule(
20+
func_obj: CapsuleType | int, name_obj: str, context_obj: CapsuleType | int
21+
) -> CapsuleType: ... # undocumented
22+
23+
# namespace pollution
24+
idx: int = 1 # undocumented
25+
sig: tuple[bytes, int] = ... # undocumented
26+
sigs: list[tuple[bytes, int]] = ... # undocumented
27+
28+
def test_call_cython(callback_obj: Callable[[float], SupportsFloat], value: float) -> float: ... # undocumented

uv.lock

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)