Skip to content

Commit c491105

Browse files
authored
Update type annotations in compare.c and vector.c (#663)
- Add type annotations for functions in compare.c - Update annotations in vector.c to match behavior Fixes #662
1 parent acfec3e commit c491105

File tree

3 files changed

+153
-34
lines changed

3 files changed

+153
-34
lines changed

code/numpy/compare.c

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,23 @@ static mp_obj_t compare_equal_helper(mp_obj_t x1, mp_obj_t x2, uint8_t comptype)
140140
#endif
141141

142142
#if ULAB_NUMPY_HAS_CLIP
143-
143+
//| def clip(
144+
//| a: _ScalarOrArrayLike,
145+
//| a_min: _ScalarOrArrayLike,
146+
//| a_max: _ScalarOrArrayLike,
147+
//| ) -> _ScalarOrNdArray:
148+
//| """
149+
//| Clips (limits) the values in an array.
150+
//|
151+
//| :param a: Scalar or array containing elements to clip.
152+
//| :param a_min: Minimum value, it will be broadcast against ``a``.
153+
//| :param a_max: Maximum value, it will be broadcast against ``a``.
154+
//| :return:
155+
//| A scalar or array with the elements of ``a``, but where
156+
//| values < ``a_min`` are replaced with ``a_min``, and those
157+
//| > ``a_max`` with ``a_max``.
158+
//| """
159+
//| ...
144160
mp_obj_t compare_clip(mp_obj_t x1, mp_obj_t x2, mp_obj_t x3) {
145161
// Note: this function could be made faster by implementing a single-loop comparison in
146162
// RUN_COMPARE_LOOP. However, that would add around 2 kB of compile size, while we
@@ -166,7 +182,18 @@ MP_DEFINE_CONST_FUN_OBJ_3(compare_clip_obj, compare_clip);
166182
#endif
167183

168184
#if ULAB_NUMPY_HAS_EQUAL
169-
185+
//| def equal(x: _ScalarOrArrayLike, y: _ScalarOrArrayLike) -> _ScalarOrNdArray:
186+
//| """
187+
//| Returns ``x == y`` element-wise.
188+
//|
189+
//| :param x, y:
190+
//| Input scalar or array. If ``x.shape != y.shape`` they must
191+
//| be broadcastable to a common shape (which becomes the
192+
//| shape of the output.)
193+
//| :return:
194+
//| A boolean scalar or array with the element-wise result of ``x == y``.
195+
//| """
196+
//| ...
170197
mp_obj_t compare_equal(mp_obj_t x1, mp_obj_t x2) {
171198
return compare_equal_helper(x1, x2, COMPARE_EQUAL);
172199
}
@@ -175,7 +202,21 @@ MP_DEFINE_CONST_FUN_OBJ_2(compare_equal_obj, compare_equal);
175202
#endif
176203

177204
#if ULAB_NUMPY_HAS_NOTEQUAL
178-
205+
//| def not_equal(
206+
//| x: _ScalarOrArrayLike,
207+
//| y: _ScalarOrArrayLike,
208+
//| ) -> Union[_bool, ulab.numpy.ndarray]:
209+
//| """
210+
//| Returns ``x != y`` element-wise.
211+
//|
212+
//| :param x, y:
213+
//| Input scalar or array. If ``x.shape != y.shape`` they must
214+
//| be broadcastable to a common shape (which becomes the
215+
//| shape of the output.)
216+
//| :return:
217+
//| A boolean scalar or array with the element-wise result of ``x != y``.
218+
//| """
219+
//| ...
179220
mp_obj_t compare_not_equal(mp_obj_t x1, mp_obj_t x2) {
180221
return compare_equal_helper(x1, x2, COMPARE_NOT_EQUAL);
181222
}
@@ -270,6 +311,16 @@ static mp_obj_t compare_isinf_isfinite(mp_obj_t _x, uint8_t mask) {
270311
#endif
271312

272313
#if ULAB_NUMPY_HAS_ISFINITE
314+
//| def isfinite(x: _ScalarOrNdArray) -> Union[_bool, ulab.numpy.ndarray]:
315+
//| """
316+
//| Tests element-wise for finiteness (i.e., it should not be infinity or a NaN).
317+
//|
318+
//| :param x: Input scalar or ndarray.
319+
//| :return:
320+
//| A boolean scalar or array with True where ``x`` is finite, and
321+
//| False otherwise.
322+
//| """
323+
//| ...
273324
mp_obj_t compare_isfinite(mp_obj_t _x) {
274325
return compare_isinf_isfinite(_x, 0);
275326
}
@@ -278,6 +329,16 @@ MP_DEFINE_CONST_FUN_OBJ_1(compare_isfinite_obj, compare_isfinite);
278329
#endif
279330

280331
#if ULAB_NUMPY_HAS_ISINF
332+
//| def isinf(x: _ScalarOrNdArray) -> Union[_bool, ulab.numpy.ndarray]:
333+
//| """
334+
//| Tests element-wise for positive or negative infinity.
335+
//|
336+
//| :param x: Input scalar or ndarray.
337+
//| :return:
338+
//| A boolean scalar or array with True where ``x`` is positive or
339+
//| negative infinity, and False otherwise.
340+
//| """
341+
//| ...
281342
mp_obj_t compare_isinf(mp_obj_t _x) {
282343
return compare_isinf_isfinite(_x, 1);
283344
}
@@ -286,6 +347,18 @@ MP_DEFINE_CONST_FUN_OBJ_1(compare_isinf_obj, compare_isinf);
286347
#endif
287348

288349
#if ULAB_NUMPY_HAS_MAXIMUM
350+
//| def maximum(x1: _ScalarOrArrayLike, x2: _ScalarOrArrayLike) -> _ScalarOrNdArray:
351+
//| """
352+
//| Returns the element-wise maximum.
353+
//|
354+
//| :param x1, x2:
355+
//| Input scalar or array. If ``x.shape != y.shape`` they must
356+
//| be broadcastable to a common shape (which becomes the
357+
//| shape of the output.)
358+
//| :return:
359+
//| A scalar or array with the element-wise maximum of ``x1`` and ``x2``.
360+
//| """
361+
//| ...
289362
mp_obj_t compare_maximum(mp_obj_t x1, mp_obj_t x2) {
290363
// extra round, so that we can return maximum(3, 4) properly
291364
mp_obj_t result = compare_function(x1, x2, COMPARE_MAXIMUM);
@@ -301,6 +374,18 @@ MP_DEFINE_CONST_FUN_OBJ_2(compare_maximum_obj, compare_maximum);
301374

302375
#if ULAB_NUMPY_HAS_MINIMUM
303376

377+
//| def minimum(x1: _ScalarOrArrayLike, x2: _ScalarOrArrayLike) -> _ScalarOrNdArray:
378+
//| """
379+
//| Returns the element-wise minimum.
380+
//|
381+
//| :param x1, x2:
382+
//| Input scalar or array. If ``x.shape != y.shape`` they must
383+
//| be broadcastable to a common shape (which becomes the
384+
//| shape of the output.)
385+
//| :return:
386+
//| A scalar or array with the element-wise minimum of ``x1`` and ``x2``.
387+
//| """
388+
//| ...
304389
mp_obj_t compare_minimum(mp_obj_t x1, mp_obj_t x2) {
305390
// extra round, so that we can return minimum(3, 4) properly
306391
mp_obj_t result = compare_function(x1, x2, COMPARE_MINIMUM);
@@ -316,6 +401,17 @@ MP_DEFINE_CONST_FUN_OBJ_2(compare_minimum_obj, compare_minimum);
316401

317402
#if ULAB_NUMPY_HAS_NONZERO
318403

404+
//| def nonzero(x: _ScalarOrArrayLike) -> ulab.numpy.ndarray:
405+
//| """
406+
//| Returns the indices of elements that are non-zero.
407+
//|
408+
//| :param x:
409+
//| Input scalar or array. If ``x`` is a scalar, it is treated
410+
//| as a single-element 1-d array.
411+
//| :return:
412+
//| An array of indices that are non-zero.
413+
//| """
414+
//| ...
319415
mp_obj_t compare_nonzero(mp_obj_t x) {
320416
ndarray_obj_t *ndarray_x = ndarray_from_mp_obj(x, 0);
321417
// since ndarray_new_linear_array calls m_new0, the content of zero is a single zero
@@ -446,6 +542,27 @@ MP_DEFINE_CONST_FUN_OBJ_1(compare_nonzero_obj, compare_nonzero);
446542

447543
#if ULAB_NUMPY_HAS_WHERE
448544

545+
//| def where(
546+
//| condition: _ScalarOrArrayLike,
547+
//| x: _ScalarOrArrayLike,
548+
//| y: _ScalarOrArrayLike,
549+
//| ) -> ulab.numpy.ndarray:
550+
//| """
551+
//| Returns elements from ``x`` or ``y`` depending on ``condition``.
552+
//|
553+
//| :param condition:
554+
//| Input scalar or array. If an element (or scalar) is truthy,
555+
//| the corresponding element from ``x`` is chosen, otherwise
556+
//| ``y`` is used. ``condition``, ``x`` and ``y`` must also be
557+
//| broadcastable to the same shape (which becomes the output
558+
//| shape.)
559+
//| :param x, y:
560+
//| Input scalar or array.
561+
//| :return:
562+
//| An array with elements from ``x`` when ``condition`` is
563+
//| truthy, and ``y`` elsewhere.
564+
//| """
565+
//| ...
449566
mp_obj_t compare_where(mp_obj_t _condition, mp_obj_t _x, mp_obj_t _y) {
450567
// this implementation will work with ndarrays, and scalars only
451568
ndarray_obj_t *c = ndarray_from_mp_obj(_condition, 0);

code/numpy/numerical.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ enum NUMERICAL_FUNCTION_TYPE {
4545
//| from typing import Dict
4646
//|
4747
//| _ArrayLike = Union[ndarray, List[_float], Tuple[_float], range]
48+
//| _ScalarOrArrayLike = Union[int, _float, _ArrayLike]
49+
//| _ScalarOrNdArray = Union[int, _float, ndarray]
4850
//|
4951
//| _DType = int
5052
//| """`ulab.numpy.int8`, `ulab.numpy.uint8`, `ulab.numpy.int16`, `ulab.numpy.uint16`, `ulab.numpy.float` or `ulab.numpy.bool`"""

0 commit comments

Comments
 (0)