Skip to content

Commit 481e392

Browse files
authored
🐛 optimize.least_squares: Allow scalar floating point return type in residual function
1 parent 05cbc31 commit 481e392

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

scipy-stubs/optimize/_lsq/least_squares.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ _XScale: TypeAlias = onp.ToFloat | onp.ToFloatND | _XScaleMethod
3030
_LossMethod: TypeAlias = Literal["linear", "soft_l1", "huber", "cauchy", "arctan"]
3131
_Loss: TypeAlias = _UserLossFunction | _LossMethod
3232

33-
_ResidFunction: TypeAlias = Callable[Concatenate[_Float1D, ...], onp.ToFloat1D]
33+
_ResidFunction: TypeAlias = Callable[Concatenate[_Float1D, ...], onp.ToFloat1D | onp.ToFloat]
3434

3535
_ResultStatus: TypeAlias = Literal[-2, -1, 0, 1, 2, 3, 4]
3636

tests/optimize/test_least_squares.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,8 @@ x0: npt.NDArray[np.float64]
1515

1616
result = least_squares(example_residual, x0=x0)
1717
assert_type(result.x, np.ndarray[tuple[int], np.dtype[np.float64]])
18+
19+
def example_residual_scalar(x: npt.NDArray[np.float64]) -> np.float64 | float: ...
20+
21+
result = least_squares(example_residual_scalar, x0=x0)
22+
assert_type(result.x, np.ndarray[tuple[int], np.dtype[np.float64]])

0 commit comments

Comments
 (0)