Skip to content

Commit 33eb269

Browse files
committed
🎨 linalg: simplified _matfuncs type aliases
1 parent 8fea2ac commit 33eb269

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

‎scipy-stubs/linalg/_matfuncs.pyi

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ _Complex128ND: TypeAlias = onp.ArrayND[np.complex128]
4141
_ComplexND: TypeAlias = onp.ArrayND[npc.complexfloating]
4242
_InexactND: TypeAlias = onp.ArrayND[npc.inexact]
4343

44+
_AsFloat64ND: TypeAlias = onp.ToArrayND[float, np.float64 | npc.integer | np.bool_]
45+
4446
###
4547

4648
eps: Final[np.float64] = ... # undocumented
@@ -77,7 +79,7 @@ def fractional_matrix_power(A: onp.ToComplexND, t: onp.ToFloat) -> onp.ArrayND[A
7779

7880
#
7981
@overload
80-
def sqrtm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
82+
def sqrtm(A: _AsFloat64ND) -> _Float64ND: ...
8183
@overload
8284
def sqrtm(A: onp.ToFloatND) -> _FloatND: ...
8385
@overload
@@ -106,7 +108,7 @@ def sqrtm(A: onp.ToComplexND, disp: onp.ToFalse, blocksize: int) -> tuple[_Inexa
106108

107109
# NOTE: return dtype depends on the sign of the values
108110
@overload # +integer | ~float64
109-
def logm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND | _Complex128ND: ...
111+
def logm(A: _AsFloat64ND) -> _Float64ND | _Complex128ND: ...
110112
@overload # +floating
111113
def logm(A: onp.ToFloatND) -> _InexactND: ...
112114
@overload # ~complex128
@@ -127,7 +129,7 @@ def logm(A: onp.ToComplexND, disp: onp.ToFalse) -> tuple[_InexactND, float]: ...
127129

128130
#
129131
@overload # +integer | ~float64
130-
def expm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
132+
def expm(A: _AsFloat64ND) -> _Float64ND: ...
131133
@overload # +floating
132134
def expm(A: onp.ToFloatND) -> _FloatND: ...
133135
@overload # ~complex128
@@ -142,7 +144,7 @@ def _exp_sinch(x: onp.ArrayND[_ComplexT, _ShapeT]) -> onp.ArrayND[_ComplexT, _Sh
142144

143145
#
144146
@overload # +integer | ~float64
145-
def cosm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
147+
def cosm(A: _AsFloat64ND) -> _Float64ND: ...
146148
@overload # +floating
147149
def cosm(A: onp.ToFloatND) -> _FloatND: ...
148150
@overload # ~complex128
@@ -154,7 +156,7 @@ def cosm(A: onp.ToComplexND) -> _InexactND: ...
154156

155157
#
156158
@overload # +integer | ~float64
157-
def sinm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
159+
def sinm(A: _AsFloat64ND) -> _Float64ND: ...
158160
@overload # +floating
159161
def sinm(A: onp.ToFloatND) -> _FloatND: ...
160162
@overload # ~complex128
@@ -166,7 +168,7 @@ def sinm(A: onp.ToComplexND) -> _InexactND: ...
166168

167169
#
168170
@overload # +integer | ~float64
169-
def tanm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
171+
def tanm(A: _AsFloat64ND) -> _Float64ND: ...
170172
@overload # +floating
171173
def tanm(A: onp.ToFloatND) -> _FloatND: ...
172174
@overload # ~complex128
@@ -178,7 +180,7 @@ def tanm(A: onp.ToComplexND) -> _InexactND: ...
178180

179181
#
180182
@overload # +integer | ~float64
181-
def coshm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
183+
def coshm(A: _AsFloat64ND) -> _Float64ND: ...
182184
@overload # +floating
183185
def coshm(A: onp.ToFloatND) -> _FloatND: ...
184186
@overload # ~complex128
@@ -190,7 +192,7 @@ def coshm(A: onp.ToComplexND) -> _InexactND: ...
190192

191193
#
192194
@overload # +integer | ~float64
193-
def sinhm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
195+
def sinhm(A: _AsFloat64ND) -> _Float64ND: ...
194196
@overload # +floating
195197
def sinhm(A: onp.ToFloatND) -> _FloatND: ...
196198
@overload # ~complex128
@@ -202,7 +204,7 @@ def sinhm(A: onp.ToComplexND) -> _InexactND: ...
202204

203205
#
204206
@overload # +integer | ~float64
205-
def tanhm(A: onp.ToIntND | onp.ToJustFloat64_ND) -> _Float64ND: ...
207+
def tanhm(A: _AsFloat64ND) -> _Float64ND: ...
206208
@overload # +floating
207209
def tanhm(A: onp.ToFloatND) -> _FloatND: ...
208210
@overload # ~complex128
@@ -228,7 +230,7 @@ def funm(A: onp.CanArrayND[_InexactT], func: _FuncND[_InexactT], disp: onp.ToFal
228230

229231
#
230232
@overload # +float64
231-
def signm(A: onp.ToIntND | onp.ToFloat64_ND) -> _Float64ND: ...
233+
def signm(A: _AsFloat64ND) -> _Float64ND: ...
232234
@overload # +floating
233235
def signm(A: onp.ToFloatND) -> _FloatND: ...
234236
@overload # +complexfloating
@@ -248,7 +250,7 @@ def signm(A: onp.ToComplexND, disp: onp.ToTrue) -> _InexactND: ...
248250
def signm(A: onp.ToComplexND, disp: onp.ToFalse) -> tuple[_InexactND, np.float64]: ...
249251

250252
#
251-
@overload # +integer, +integer
253+
@overload # +integer | ~float64
252254
def khatri_rao(a: onp.ToIntND, b: onp.ToIntND) -> _IntND: ...
253255
@overload # +float64, ~float64
254256
def khatri_rao(a: onp.ToFloat64_ND, b: onp.ToJustFloat64_ND) -> _Float64ND: ...

0 commit comments

Comments
 (0)