Skip to content
This repository was archived by the owner on Feb 17, 2021. It is now read-only.

Commit e4f26ee

Browse files
committed
Improve scalar + array, and add complex support
1 parent 3fb733e commit e4f26ee

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

numpy-stubs/__init__.pyi

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,17 @@ class void:
4242
# and this is why we let int32 be a subclass of int64; and similarly for float32 and float64
4343
# the same logic applies when adding unsigned and signed values (uint + int -> int)
4444

45+
class complex128(void, int):
46+
def __complex__(self) -> complex: ...
47+
48+
class complex64(complex128): ...
49+
4550
# this would be the correct definition, but it makes `int` conflict with `float`
4651
# class float64(void, float): ...
47-
class float64(void, int):
52+
class float64(complex128):
4853
def __float__(self) -> float: ...
4954

50-
class float32(float64): ...
55+
class float32(float64, complex64): ...
5156
class float16(float32): ...
5257

5358
floating = float64
@@ -70,6 +75,8 @@ integer = int64
7075
_DType = TypeVar(
7176
"_DType",
7277
bool_,
78+
complex64,
79+
complex128,
7380
float16,
7481
float32,
7582
float64,
@@ -87,6 +94,8 @@ _DType = TypeVar(
8794
_DType2 = TypeVar(
8895
"_DType2",
8996
bool_,
97+
complex64,
98+
complex128,
9099
float16,
91100
float32,
92101
float64,
@@ -110,7 +119,7 @@ _ScalarLike = Union[_DType, str, int, float]
110119
_ConditionType = Union[ndarray[bool_], bool_, bool]
111120
newaxis: None = ...
112121

113-
_AnyNum = Union[int, float, bool]
122+
_AnyNum = Union[int, float, bool, complex]
114123
# generic types that are only allowed to take on dtype values
115124

116125
_Float = TypeVar("_Float", float16, float32, float64)
@@ -391,7 +400,7 @@ class ndarray(Generic[_DType]):
391400
@overload
392401
def __radd__(self, value: _DType) -> ndarray[_DType]: ...
393402
@overload
394-
def __radd__(self, value: float) -> ndarray[_DType]: ...
403+
def __radd__(self, value: float) -> ndarray[_DType2]: ...
395404
def __rand__(self, value: object) -> ndarray[_DType]: ...
396405
def __rdivmod__(self, value: object) -> Tuple[ndarray[_DType], ndarray[_DType]]: ...
397406
def __rfloordiv__(self, value: object) -> ndarray[_DType]: ...
@@ -506,6 +515,8 @@ def array(object: _NestedList[int]) -> ndarray[int64]: ...
506515
@overload
507516
def array(object: _NestedList[float]) -> ndarray[float64]: ...
508517
@overload
518+
def array(object: _NestedList[complex]) -> ndarray[complex64]: ...
519+
@overload
509520
def array(object: _NestedList[str]) -> ndarray[str_]: ...
510521
@overload
511522
def array(object: str) -> ndarray[str_]: ...

tests/numpy_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
DType = TypeVar(
99
"DType",
1010
np.bool_,
11+
np.complex64,
12+
np.complex128,
1113
np.float32,
1214
np.float64,
1315
np.int8,
@@ -306,6 +308,10 @@ def test_newaxis() -> None:
306308

307309

308310
def test_sum_scalar_before() -> None:
309-
x = 273.15 + np.array([-0.1e2, -0.77e1])
311+
x: np.ndarray[np.float64] = 273.15 + np.array([10, 20])
310312
assert isinstance(x, np.ndarray)
311-
assert x.dtype == np.float64
313+
assert_dtype(x, np.float64)
314+
315+
y: np.ndarray[np.complex128] = 10.0 + np.array([1j, 2j])
316+
assert isinstance(y, np.ndarray)
317+
assert_dtype(y, np.complex128)

0 commit comments

Comments
 (0)