@@ -145,9 +145,18 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(poly_polyfit_obj, 2, 3, poly_polyfit);
145
145
146
146
#if ULAB_NUMPY_HAS_POLYVAL
147
147
148
+ static mp_float_t poly_eval (mp_float_t x , mp_float_t * p , uint8_t plen ) {
149
+ mp_float_t y = p [0 ];
150
+ for (uint8_t j = 0 ; j < plen - 1 ; j ++ ) {
151
+ y *= x ;
152
+ y += p [j + 1 ];
153
+ }
154
+ return y ;
155
+ }
156
+
148
157
mp_obj_t poly_polyval (mp_obj_t o_p , mp_obj_t o_x ) {
149
- if (!ndarray_object_is_array_like (o_p ) || ! ndarray_object_is_array_like ( o_x ) ) {
150
- mp_raise_TypeError (translate ("inputs are not iterable" ));
158
+ if (!ndarray_object_is_array_like (o_p )) {
159
+ mp_raise_TypeError (translate ("input is not iterable" ));
151
160
}
152
161
#if ULAB_SUPPORTS_COMPLEX
153
162
ndarray_obj_t * input ;
@@ -171,6 +180,10 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
171
180
i ++ ;
172
181
}
173
182
183
+ if (!ndarray_object_is_array_like (o_x )) {
184
+ return mp_obj_new_float (poly_eval (mp_obj_get_float (o_x ), p , plen ));
185
+ }
186
+
174
187
// polynomials are going to be of type float, except, when both
175
188
// the coefficients and the independent variable are integers
176
189
ndarray_obj_t * ndarray ;
@@ -198,13 +211,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
198
211
#endif
199
212
size_t l = 0 ;
200
213
do {
201
- mp_float_t y = p [0 ];
202
- mp_float_t _x = func (sarray );
203
- for (uint8_t m = 0 ; m < plen - 1 ; m ++ ) {
204
- y *= _x ;
205
- y += p [m + 1 ];
206
- }
207
- * array ++ = y ;
214
+ * array ++ = poly_eval (func (sarray ), p , plen );
208
215
sarray += source -> strides [ULAB_MAX_DIMS - 1 ];
209
216
l ++ ;
210
217
} while (l < source -> shape [ULAB_MAX_DIMS - 1 ]);
@@ -233,13 +240,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
233
240
mp_obj_iter_buf_t x_buf ;
234
241
mp_obj_t x_item , x_iterable = mp_getiter (o_x , & x_buf );
235
242
while ((x_item = mp_iternext (x_iterable )) != MP_OBJ_STOP_ITERATION ) {
236
- mp_float_t _x = mp_obj_get_float (x_item );
237
- mp_float_t y = p [0 ];
238
- for (uint8_t j = 0 ; j < plen - 1 ; j ++ ) {
239
- y *= _x ;
240
- y += p [j + 1 ];
241
- }
242
- * array ++ = y ;
243
+ * array ++ = poly_eval (mp_obj_get_float (x_item ), p , plen );
243
244
}
244
245
}
245
246
m_del (mp_float_t , p , plen );
0 commit comments