@@ -857,6 +857,194 @@ mp_obj_t ndarray_binary_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
857
857
}
858
858
#endif /* NDARRAY_HAS_BINARY_OP_POWER */
859
859
860
+ #if NDARRAY_HAS_BINARY_OP_OR | NDARRAY_HAS_BINARY_OP_XOR | NDARRAY_HAS_BINARY_OP_AND
861
+ mp_obj_t ndarray_binary_logical (ndarray_obj_t * lhs , ndarray_obj_t * rhs ,
862
+ uint8_t ndim , size_t * shape , int32_t * lstrides , int32_t * rstrides , mp_binary_op_t op ) {
863
+
864
+ #if ULAB_SUPPORTS_COMPLEX
865
+ if ((lhs -> dtype == NDARRAY_COMPLEX ) || (rhs -> dtype == NDARRAY_COMPLEX ) || (lhs -> dtype == NDARRAY_FLOAT ) || (rhs -> dtype == NDARRAY_FLOAT )) {
866
+ mp_raise_TypeError (translate ("operation not supported for the input types" ));
867
+ }
868
+ #else
869
+ if ((lhs -> dtype == NDARRAY_FLOAT ) || (rhs -> dtype == NDARRAY_FLOAT )) {
870
+ mp_raise_TypeError (translate ("operation not supported for the input types" ));
871
+ }
872
+ #endif
873
+
874
+ // bail out, if both inputs are of 16-bit types, but differ in sign;
875
+ // numpy promotes the result to int32
876
+ if (((lhs -> dtype == NDARRAY_INT16 ) && (rhs -> dtype == NDARRAY_UINT16 )) ||
877
+ ((lhs -> dtype == NDARRAY_UINT16 ) && (rhs -> dtype == NDARRAY_INT16 ))) {
878
+ mp_raise_TypeError (translate ("dtype of int32 is not supported" ));
879
+ }
880
+
881
+ ndarray_obj_t * results = NULL ;
882
+ uint8_t * larray = (uint8_t * )lhs -> array ;
883
+ uint8_t * rarray = (uint8_t * )rhs -> array ;
884
+
885
+
886
+ switch (op ) {
887
+ case MP_BINARY_OP_XOR :
888
+ if (lhs -> dtype == NDARRAY_UINT8 ) {
889
+ if (rhs -> dtype == NDARRAY_UINT8 ) {
890
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT8 );
891
+ if (lhs -> boolean & rhs -> boolean ) {
892
+ results -> boolean = 1 ;
893
+ }
894
+ BINARY_LOOP (results , uint8_t , uint8_t , uint8_t , larray , lstrides , rarray , rstrides , ^);
895
+ } else if (rhs -> dtype == NDARRAY_INT8 ) {
896
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
897
+ BINARY_LOOP (results , int16_t , uint8_t , int8_t , larray , lstrides , rarray , rstrides , ^);
898
+ } else if (rhs -> dtype == NDARRAY_UINT16 ) {
899
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT16 );
900
+ BINARY_LOOP (results , uint16_t , uint8_t , uint16_t , larray , lstrides , rarray , rstrides , ^);
901
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
902
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
903
+ BINARY_LOOP (results , int16_t , uint8_t , int16_t , larray , lstrides , rarray , rstrides , ^);
904
+ }
905
+ } else if (lhs -> dtype == NDARRAY_INT8 ) {
906
+ if (rhs -> dtype == NDARRAY_INT8 ) {
907
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT8 );
908
+ BINARY_LOOP (results , int8_t , int8_t , int8_t , larray , lstrides , rarray , rstrides , ^);
909
+ } else if (rhs -> dtype == NDARRAY_UINT16 ) {
910
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
911
+ BINARY_LOOP (results , int16_t , int8_t , uint16_t , larray , lstrides , rarray , rstrides , ^);
912
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
913
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
914
+ BINARY_LOOP (results , int16_t , int8_t , int16_t , larray , lstrides , rarray , rstrides , ^);
915
+ } else {
916
+ return ndarray_binary_op (MP_BINARY_OP_XOR , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
917
+ }
918
+ } else if (lhs -> dtype == NDARRAY_UINT16 ) {
919
+ if (rhs -> dtype == NDARRAY_UINT16 ) {
920
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT16 );
921
+ BINARY_LOOP (results , uint16_t , uint16_t , uint16_t , larray , lstrides , rarray , rstrides , ^);
922
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
923
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_FLOAT );
924
+ BINARY_LOOP (results , mp_float_t , uint16_t , int16_t , larray , lstrides , rarray , rstrides , ^);
925
+ } else {
926
+ return ndarray_binary_op (MP_BINARY_OP_XOR , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
927
+ }
928
+ } else if (lhs -> dtype == NDARRAY_INT16 ) {
929
+ if (rhs -> dtype == NDARRAY_INT16 ) {
930
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
931
+ BINARY_LOOP (results , int16_t , int16_t , int16_t , larray , lstrides , rarray , rstrides , ^);
932
+ } else {
933
+ return ndarray_binary_op (MP_BINARY_OP_XOR , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
934
+ }
935
+ }
936
+ break ;
937
+
938
+ case MP_BINARY_OP_OR :
939
+ if (lhs -> dtype == NDARRAY_UINT8 ) {
940
+ if (rhs -> dtype == NDARRAY_UINT8 ) {
941
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT8 );
942
+ if (lhs -> boolean & rhs -> boolean ) {
943
+ results -> boolean = 1 ;
944
+ }
945
+ BINARY_LOOP (results , uint8_t , uint8_t , uint8_t , larray , lstrides , rarray , rstrides , |);
946
+ } else if (rhs -> dtype == NDARRAY_INT8 ) {
947
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
948
+ BINARY_LOOP (results , int16_t , uint8_t , int8_t , larray , lstrides , rarray , rstrides , |);
949
+ } else if (rhs -> dtype == NDARRAY_UINT16 ) {
950
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT16 );
951
+ BINARY_LOOP (results , uint16_t , uint8_t , uint16_t , larray , lstrides , rarray , rstrides , |);
952
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
953
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
954
+ BINARY_LOOP (results , int16_t , uint8_t , int16_t , larray , lstrides , rarray , rstrides , |);
955
+ }
956
+ } else if (lhs -> dtype == NDARRAY_INT8 ) {
957
+ if (rhs -> dtype == NDARRAY_INT8 ) {
958
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT8 );
959
+ BINARY_LOOP (results , int8_t , int8_t , int8_t , larray , lstrides , rarray , rstrides , |);
960
+ } else if (rhs -> dtype == NDARRAY_UINT16 ) {
961
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
962
+ BINARY_LOOP (results , int16_t , int8_t , uint16_t , larray , lstrides , rarray , rstrides , |);
963
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
964
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
965
+ BINARY_LOOP (results , int16_t , int8_t , int16_t , larray , lstrides , rarray , rstrides , |);
966
+ } else {
967
+ return ndarray_binary_op (MP_BINARY_OP_OR , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
968
+ }
969
+ } else if (lhs -> dtype == NDARRAY_UINT16 ) {
970
+ if (rhs -> dtype == NDARRAY_UINT16 ) {
971
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT16 );
972
+ BINARY_LOOP (results , uint16_t , uint16_t , uint16_t , larray , lstrides , rarray , rstrides , |);
973
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
974
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_FLOAT );
975
+ BINARY_LOOP (results , mp_float_t , uint16_t , int16_t , larray , lstrides , rarray , rstrides , |);
976
+ } else {
977
+ return ndarray_binary_op (MP_BINARY_OP_OR , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
978
+ }
979
+ } else if (lhs -> dtype == NDARRAY_INT16 ) {
980
+ if (rhs -> dtype == NDARRAY_INT16 ) {
981
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
982
+ BINARY_LOOP (results , int16_t , int16_t , int16_t , larray , lstrides , rarray , rstrides , |);
983
+ } else {
984
+ return ndarray_binary_op (MP_BINARY_OP_OR , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
985
+ }
986
+ }
987
+ break ;
988
+
989
+ case MP_BINARY_OP_AND :
990
+ if (lhs -> dtype == NDARRAY_UINT8 ) {
991
+ if (rhs -> dtype == NDARRAY_UINT8 ) {
992
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT8 );
993
+ if (lhs -> boolean & rhs -> boolean ) {
994
+ results -> boolean = 1 ;
995
+ }
996
+ BINARY_LOOP (results , uint8_t , uint8_t , uint8_t , larray , lstrides , rarray , rstrides , & );
997
+ } else if (rhs -> dtype == NDARRAY_INT8 ) {
998
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
999
+ BINARY_LOOP (results , int16_t , uint8_t , int8_t , larray , lstrides , rarray , rstrides , & );
1000
+ } else if (rhs -> dtype == NDARRAY_UINT16 ) {
1001
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT16 );
1002
+ BINARY_LOOP (results , uint16_t , uint8_t , uint16_t , larray , lstrides , rarray , rstrides , & );
1003
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
1004
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
1005
+ BINARY_LOOP (results , int16_t , uint8_t , int16_t , larray , lstrides , rarray , rstrides , & );
1006
+ }
1007
+ } else if (lhs -> dtype == NDARRAY_INT8 ) {
1008
+ if (rhs -> dtype == NDARRAY_INT8 ) {
1009
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT8 );
1010
+ BINARY_LOOP (results , int8_t , int8_t , int8_t , larray , lstrides , rarray , rstrides , & );
1011
+ } else if (rhs -> dtype == NDARRAY_UINT16 ) {
1012
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
1013
+ BINARY_LOOP (results , int16_t , int8_t , uint16_t , larray , lstrides , rarray , rstrides , & );
1014
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
1015
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
1016
+ BINARY_LOOP (results , int16_t , int8_t , int16_t , larray , lstrides , rarray , rstrides , & );
1017
+ } else {
1018
+ return ndarray_binary_op (MP_BINARY_OP_AND , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
1019
+ }
1020
+ } else if (lhs -> dtype == NDARRAY_UINT16 ) {
1021
+ if (rhs -> dtype == NDARRAY_UINT16 ) {
1022
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_UINT16 );
1023
+ BINARY_LOOP (results , uint16_t , uint16_t , uint16_t , larray , lstrides , rarray , rstrides , & );
1024
+ } else if (rhs -> dtype == NDARRAY_INT16 ) {
1025
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_FLOAT );
1026
+ BINARY_LOOP (results , mp_float_t , uint16_t , int16_t , larray , lstrides , rarray , rstrides , & );
1027
+ } else {
1028
+ return ndarray_binary_op (MP_BINARY_OP_AND , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
1029
+ }
1030
+ } else if (lhs -> dtype == NDARRAY_INT16 ) {
1031
+ if (rhs -> dtype == NDARRAY_INT16 ) {
1032
+ results = ndarray_new_dense_ndarray (ndim , shape , NDARRAY_INT16 );
1033
+ BINARY_LOOP (results , int16_t , int16_t , int16_t , larray , lstrides , rarray , rstrides , & );
1034
+ } else {
1035
+ return ndarray_binary_op (MP_BINARY_OP_AND , MP_OBJ_FROM_PTR (rhs ), MP_OBJ_FROM_PTR (lhs ));
1036
+ }
1037
+ }
1038
+ break ;
1039
+ default :
1040
+ return MP_OBJ_NULL ; // op not supported
1041
+ break ;
1042
+ }
1043
+ return MP_OBJ_FROM_PTR (results );
1044
+ }
1045
+
1046
+ #endif /* NDARRAY_HAS_BINARY_OP_OR | NDARRAY_HAS_BINARY_OP_XOR | NDARRAY_HAS_BINARY_OP_AND */
1047
+
860
1048
#if NDARRAY_HAS_INPLACE_ADD || NDARRAY_HAS_INPLACE_MULTIPLY || NDARRAY_HAS_INPLACE_SUBTRACT
861
1049
mp_obj_t ndarray_inplace_ams (ndarray_obj_t * lhs , ndarray_obj_t * rhs , int32_t * rstrides , uint8_t optype ) {
862
1050
0 commit comments