Skip to content

Commit 47ad73a

Browse files
v923zmatemaciek
andauthored
Floordiv (#593)
* implement floor division * fix 3D, 4D loops * add missing array declaration in 3D, and 4D * Add test cases for floor division and fix it for ints (#599) * Add test cases for floor division * Fix define name in comment * Fix floor division of ints --------- Co-authored-by: Maciej Sokołowski <[email protected]>
1 parent 4407f8c commit 47ad73a

File tree

6 files changed

+392
-2
lines changed

6 files changed

+392
-2
lines changed

code/ndarray.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1936,6 +1936,12 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
19361936
return ndarray_binary_power(lhs, rhs, ndim, shape, lstrides, rstrides);
19371937
break;
19381938
#endif
1939+
#if NDARRAY_HAS_BINARY_OP_FLOOR_DIVIDE
1940+
case MP_BINARY_OP_FLOOR_DIVIDE:
1941+
COMPLEX_DTYPE_NOT_IMPLEMENTED(lhs->dtype);
1942+
return ndarray_binary_floor_divide(lhs, rhs, ndim, shape, lstrides, rstrides);
1943+
break;
1944+
#endif
19391945
default:
19401946
return MP_OBJ_NULL; // op not supported
19411947
break;

code/ndarray_operators.c

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,102 @@ mp_obj_t ndarray_binary_true_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
673673
}
674674
#endif /* NDARRAY_HAS_BINARY_OP_TRUE_DIVIDE */
675675

676+
#if NDARRAY_HAS_BINARY_OP_FLOOR_DIVIDE
677+
mp_obj_t ndarray_binary_floor_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
678+
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
679+
680+
ndarray_obj_t *results = NULL;
681+
uint8_t *larray = (uint8_t *)lhs->array;
682+
uint8_t *rarray = (uint8_t *)rhs->array;
683+
684+
if(lhs->dtype == NDARRAY_UINT8) {
685+
if(rhs->dtype == NDARRAY_UINT8) {
686+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
687+
FLOOR_DIVIDE_LOOP_UINT(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides);
688+
} else if(rhs->dtype == NDARRAY_INT8) {
689+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
690+
FLOOR_DIVIDE_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides);
691+
} else if(rhs->dtype == NDARRAY_UINT16) {
692+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
693+
FLOOR_DIVIDE_LOOP_UINT(results, uint16_t, uint8_t, uint16_t, larray, lstrides, rarray, rstrides);
694+
} else if(rhs->dtype == NDARRAY_INT16) {
695+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
696+
FLOOR_DIVIDE_LOOP(results, int16_t, uint8_t, int16_t, larray, lstrides, rarray, rstrides);
697+
} else if(rhs->dtype == NDARRAY_FLOAT) {
698+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
699+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, uint8_t, mp_float_t, larray, lstrides, rarray, rstrides);
700+
}
701+
} else if(lhs->dtype == NDARRAY_INT8) {
702+
if(rhs->dtype == NDARRAY_UINT8) {
703+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
704+
FLOOR_DIVIDE_LOOP(results, int16_t, int8_t, uint8_t, larray, lstrides, rarray, rstrides);
705+
} else if(rhs->dtype == NDARRAY_INT8) {
706+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT8);
707+
FLOOR_DIVIDE_LOOP(results, int8_t, int8_t, int8_t, larray, lstrides, rarray, rstrides);
708+
} else if(rhs->dtype == NDARRAY_UINT16) {
709+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
710+
FLOOR_DIVIDE_LOOP(results, uint16_t, int8_t, uint16_t, larray, lstrides, rarray, rstrides);
711+
} else if(rhs->dtype == NDARRAY_INT16) {
712+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
713+
FLOOR_DIVIDE_LOOP(results, int16_t, int8_t, int16_t, larray, lstrides, rarray, rstrides);
714+
} else if(rhs->dtype == NDARRAY_FLOAT) {
715+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
716+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, int8_t, mp_float_t, larray, lstrides, rarray, rstrides);
717+
}
718+
} else if(lhs->dtype == NDARRAY_UINT16) {
719+
if(rhs->dtype == NDARRAY_UINT8) {
720+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
721+
FLOOR_DIVIDE_LOOP_UINT(results, uint16_t, uint16_t, uint8_t, larray, lstrides, rarray, rstrides);
722+
} else if(rhs->dtype == NDARRAY_INT8) {
723+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
724+
FLOOR_DIVIDE_LOOP(results, uint16_t, uint16_t, int8_t, larray, lstrides, rarray, rstrides);
725+
} else if(rhs->dtype == NDARRAY_UINT16) {
726+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
727+
FLOOR_DIVIDE_LOOP_UINT(results, uint16_t, uint16_t, uint16_t, larray, lstrides, rarray, rstrides);
728+
} else if(rhs->dtype == NDARRAY_INT16) {
729+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
730+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, uint16_t, int16_t, larray, lstrides, rarray, rstrides);
731+
} else if(rhs->dtype == NDARRAY_FLOAT) {
732+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
733+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, uint16_t, mp_float_t, larray, lstrides, rarray, rstrides);
734+
}
735+
} else if(lhs->dtype == NDARRAY_INT16) {
736+
if(rhs->dtype == NDARRAY_UINT8) {
737+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
738+
FLOOR_DIVIDE_LOOP(results, int16_t, int16_t, uint8_t, larray, lstrides, rarray, rstrides);
739+
} else if(rhs->dtype == NDARRAY_INT8) {
740+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
741+
FLOOR_DIVIDE_LOOP(results, int16_t, int16_t, int8_t, larray, lstrides, rarray, rstrides);
742+
} else if(rhs->dtype == NDARRAY_UINT16) {
743+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
744+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, int16_t, uint16_t, larray, lstrides, rarray, rstrides);
745+
} else if(rhs->dtype == NDARRAY_INT16) {
746+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
747+
FLOOR_DIVIDE_LOOP(results, int16_t, int16_t, int16_t, larray, lstrides, rarray, rstrides);
748+
} else if(rhs->dtype == NDARRAY_FLOAT) {
749+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
750+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, uint16_t, mp_float_t, larray, lstrides, rarray, rstrides);
751+
}
752+
} else if(lhs->dtype == NDARRAY_FLOAT) {
753+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
754+
if(rhs->dtype == NDARRAY_UINT8) {
755+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, uint8_t, larray, lstrides, rarray, rstrides);
756+
} else if(rhs->dtype == NDARRAY_INT8) {
757+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, int8_t, larray, lstrides, rarray, rstrides);
758+
} else if(rhs->dtype == NDARRAY_UINT16) {
759+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, uint16_t, larray, lstrides, rarray, rstrides);
760+
} else if(rhs->dtype == NDARRAY_INT16) {
761+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, int16_t, larray, lstrides, rarray, rstrides);
762+
} else if(rhs->dtype == NDARRAY_FLOAT) {
763+
FLOOR_DIVIDE_LOOP_FLOAT(results, mp_float_t, mp_float_t, mp_float_t, larray, lstrides, rarray, rstrides);
764+
}
765+
}
766+
767+
return MP_OBJ_FROM_PTR(results);
768+
769+
}
770+
#endif /* NDARRAY_HAS_BINARY_OP_FLOOR_DIVIDE */
771+
676772
#if NDARRAY_HAS_BINARY_OP_POWER
677773
mp_obj_t ndarray_binary_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
678774
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
@@ -812,7 +908,7 @@ mp_obj_t ndarray_inplace_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t
812908
}
813909
return MP_OBJ_FROM_PTR(lhs);
814910
}
815-
#endif /* NDARRAY_HAS_INPLACE_DIVIDE */
911+
#endif /* NDARRAY_HAS_INPLACE_TRUE_DIVIDE */
816912

817913
#if NDARRAY_HAS_INPLACE_POWER
818914
mp_obj_t ndarray_inplace_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rstrides) {

0 commit comments

Comments
 (0)