Skip to content

Commit 112d4f8

Browse files
HugoNumworksv923z
andauthored
Polyval handles non-array as second argument (#601)
* Factorize polynomial evaluation * Polyval handles non-array as second argument --------- Co-authored-by: Zoltán Vörös <[email protected]>
1 parent 319df10 commit 112d4f8

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

code/numpy/poly.c

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,18 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(poly_polyfit_obj, 2, 3, poly_polyfit);
145145

146146
#if ULAB_NUMPY_HAS_POLYVAL
147147

148+
static mp_float_t poly_eval(mp_float_t x, mp_float_t *p, uint8_t plen) {
149+
mp_float_t y = p[0];
150+
for(uint8_t j=0; j < plen-1; j++) {
151+
y *= x;
152+
y += p[j+1];
153+
}
154+
return y;
155+
}
156+
148157
mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
149-
if(!ndarray_object_is_array_like(o_p) || !ndarray_object_is_array_like(o_x)) {
150-
mp_raise_TypeError(translate("inputs are not iterable"));
158+
if(!ndarray_object_is_array_like(o_p)) {
159+
mp_raise_TypeError(translate("input is not iterable"));
151160
}
152161
#if ULAB_SUPPORTS_COMPLEX
153162
ndarray_obj_t *input;
@@ -171,6 +180,10 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
171180
i++;
172181
}
173182

183+
if(!ndarray_object_is_array_like(o_x)) {
184+
return mp_obj_new_float(poly_eval(mp_obj_get_float(o_x), p, plen));
185+
}
186+
174187
// polynomials are going to be of type float, except, when both
175188
// the coefficients and the independent variable are integers
176189
ndarray_obj_t *ndarray;
@@ -198,13 +211,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
198211
#endif
199212
size_t l = 0;
200213
do {
201-
mp_float_t y = p[0];
202-
mp_float_t _x = func(sarray);
203-
for(uint8_t m=0; m < plen-1; m++) {
204-
y *= _x;
205-
y += p[m+1];
206-
}
207-
*array++ = y;
214+
*array++ = poly_eval(func(sarray), p, plen);
208215
sarray += source->strides[ULAB_MAX_DIMS - 1];
209216
l++;
210217
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
@@ -233,13 +240,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
233240
mp_obj_iter_buf_t x_buf;
234241
mp_obj_t x_item, x_iterable = mp_getiter(o_x, &x_buf);
235242
while ((x_item = mp_iternext(x_iterable)) != MP_OBJ_STOP_ITERATION) {
236-
mp_float_t _x = mp_obj_get_float(x_item);
237-
mp_float_t y = p[0];
238-
for(uint8_t j=0; j < plen-1; j++) {
239-
y *= _x;
240-
y += p[j+1];
241-
}
242-
*array++ = y;
243+
*array++ = poly_eval(mp_obj_get_float(x_item), p, plen);
243244
}
244245
}
245246
m_del(mp_float_t, p, plen);

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.3.2
36+
#define ULAB_VERSION 6.3.3
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

0 commit comments

Comments
 (0)