From aa24f9d199dbd797db24f27b31438add2fbfe801 Mon Sep 17 00:00:00 2001 From: Yen-Fu Chen Date: Sat, 8 Jun 2024 00:02:20 +0800 Subject: [PATCH] jit: Fix overflow handling in function muldivmod The behavior of the host instructions div and mod differs from that of RISC-V. Additional checks are required to align with RISC-V's DIV[U] and REM[U] behavior, particularly when handling division by zero and overflow scenarios. Close: #297 --- src/jit.c | 165 ++++++++++++++++++++++++++------------------ src/rv32_template.c | 4 +- 2 files changed, 101 insertions(+), 68 deletions(-) diff --git a/src/jit.c b/src/jit.c index 51bf8bc7..62f1349c 100644 --- a/src/jit.c +++ b/src/jit.c @@ -762,24 +762,6 @@ static inline void emit_alu64_imm32(struct jit_state *state, emit_alu64(state, op, src, dst); emit4(state, imm); } -#elif defined(__aarch64__) -static void divmod(struct jit_state *state, - uint8_t opcode, - int rd, - int rn, - int rm) -{ - bool mod = (opcode & JIT_ALU_OP_MASK) == (JIT_OP_MOD_IMM & JIT_ALU_OP_MASK); - bool is64 = (opcode & JIT_CLS_MASK) == JIT_CLS_ALU64; - int div_dest = mod ? temp_div_reg : rd; - - /* Do not need to treet divide by zero as special because the UDIV - * instruction already returns 0 when dividing by zero. 
- */ - emit_dataproc_2source(state, is64, DP2_UDIV, div_dest, rn, rm); - if (mod) - emit_dataproc_3source(state, is64, DP3_MSUB, rd, rm, div_dest, rn); -} #endif static inline void emit_cmp_imm32(struct jit_state *state, int dst, int32_t imm) @@ -1035,33 +1017,78 @@ static inline void emit_exit(struct jit_state *state) #endif } -/* TODO: muldivmod is incomplete, it does not handle imm or overflow now */ #if RV32_HAS(EXT_M) +#if defined(__x86_64__) +static inline void emit_conditional_move(struct jit_state *state, + int src, + int dst) +{ + emit1(state, 0x48); + emit1(state, 0x0f); + emit1(state, 0x44); + emit_modrm_reg2reg(state, dst, src); +} +#elif defined(__aarch64__) +static inline void emit_conditional_move(struct jit_state *state, + int rd, + int rn, + int rm, + int cond) +{ + emit_a64(state, 0x1a800000 | (rm << 16) | (cond << 12) | (rn << 5) | rd); + set_dirty(rd, true); +} + +static void divmod(struct jit_state *state, + uint8_t opcode, + int rd, + int rn, + int rm, + bool sign) +{ + bool mod = (opcode & JIT_ALU_OP_MASK) == (JIT_OP_MOD_IMM & JIT_ALU_OP_MASK); + bool is64 = (opcode & JIT_CLS_MASK) == JIT_CLS_ALU64; + int div_dest = mod ? 
temp_div_reg : rd; + + if (sign) + emit_cmp_imm32(state, rd, 0x80000000); /* overflow checking */ + + emit_dataproc_2source(state, is64, DP2_UDIV, div_dest, rn, rm); + if (mod) + emit_dataproc_3source(state, is64, DP3_MSUB, rd, rm, div_dest, rn); + + if (sign) { + /* handle overflow */ + uint32_t jump_loc = state->offset; + emit_jcc_offset(state, 0x85); + emit_cmp_imm32(state, rm, -1); + if (mod) + emit_load_imm(state, R10, 0); + else + emit_load_imm(state, R10, 0x80000000); + emit_conditional_move(state, rd, R10, rd, COND_EQ); + emit_jump_target_offset(state, JUMP_LOC, state->offset); + } + if (!mod) { + /* handle division by zero */ + emit_cmp_imm32(state, rm, 0); + emit_load_imm(state, temp_reg, -1); + emit_conditional_move(state, rd, temp_reg, rd, COND_EQ); + } +} +#endif + static void muldivmod(struct jit_state *state, uint8_t opcode, int src, int dst, - int32_t imm UNUSED) + bool sign) { #if defined(__x86_64__) bool mul = (opcode & JIT_ALU_OP_MASK) == (JIT_OP_MUL_IMM & JIT_ALU_OP_MASK); bool div = (opcode & JIT_ALU_OP_MASK) == (JIT_OP_DIV_IMM & JIT_ALU_OP_MASK); bool mod = (opcode & JIT_ALU_OP_MASK) == (JIT_OP_MOD_IMM & JIT_ALU_OP_MASK); bool is64 = (opcode & JIT_CLS_MASK) == JIT_CLS_ALU64; - bool reg = (opcode & JIT_SRC_REG) == JIT_SRC_REG; - - /* Short circuit for imm == 0 */ - if (!reg && imm == 0) { - assert(NULL); - if (div || mul) { - /* For division and multiplication, set result to zero. */ - emit_alu32(state, 0x31, dst, dst); - } else { - /* For modulo, set result to dividend. */ - emit_mov(state, dst, dst); - } - return; - } /* Record the mapping status before the registers are used for other * purposes, and restore the status after popping the registers. 
@@ -1080,51 +1107,44 @@ static void muldivmod(struct jit_state *state, } /* Load the divisor into RCX */ - if (imm) - emit_load_imm(state, RCX, imm); - else - emit_mov(state, src, RCX); + emit_mov(state, src, RCX); /* Load the dividend into RAX */ emit_mov(state, dst, RAX); /* The JIT employs two different semantics for division and modulus * operations. In the case of division, if the divisor is zero, the result - * is set to zero. For modulus operations, if the divisor is zero, the + * is set to -1. For modulus operations, if the divisor is zero, the * result becomes the dividend. To manage this, we first set the divisor to * 1 if it is initially zero. Then, we adjust the result accordingly: for - * division, we set it to zero if the original divisor was zero; for + * division, we set it to -1 if the original divisor was zero; for * modulus, we set it to the dividend under the same condition. */ if (div || mod) { - /* Check if divisor is zero */ - if (is64) - emit_alu64(state, 0x85, RCX, RCX); - else - emit_alu32(state, 0x85, RCX, RCX); - - /* Save the dividend for the modulo case */ - if (mod) + if (sign) { + emit_load_imm(state, RDX, -1); + /* compare divisor with -1 for overflow checking */ + emit_cmp32(state, RDX, RCX); + /* Save the result of the comparison */ + emit1(state, 0x9c); /* pushfq */ + } + if (mod || (div && sign)) emit_push(state, RAX); /* Save dividend */ + emit_alu32(state, 0x85, RCX, RCX); /* Save the result of the test */ emit1(state, 0x9c); /* pushfq */ /* Set the divisor to 1 if it is zero */ emit_load_imm(state, RDX, 1); - emit1(state, 0x48); - emit1(state, 0x0f); - emit1(state, 0x44); - emit1(state, 0xca); /* cmove rcx, rdx */ - + emit_conditional_move(state, RDX, RCX); /* xor %edx,%edx */ emit_alu32(state, 0x31, RDX, RDX); } if (is64) emit_rex(state, 1, 0, 0, 0); - /* Multiply or divide */ emit_alu32(state, 0xf7, mul ? 
4 : 6, RCX); @@ -1139,24 +1159,37 @@ static void muldivmod(struct jit_state *state, if (div) { /* Set the dividend to zero if the divisor was zero. */ - emit_load_imm(state, RCX, 0); + emit_load_imm(state, RCX, -1); /* Store 0 in RAX if the divisor was zero. */ /* Use conditional move to avoid a branch. */ - emit1(state, 0x48); - emit1(state, 0x0f); - emit1(state, 0x44); - emit1(state, 0xc1); /* cmove rax, rcx */ + emit_conditional_move(state, RCX, RAX); + if (sign) { + emit_pop(state, RCX); + /* handle DIV overflow */ + emit1(state, 0x9d); /* popfq */ + uint32_t jump_loc = state->offset; + emit_jcc_offset(state, 0x85); + emit_cmp_imm32(state, RCX, 0x80000000); + emit_conditional_move(state, RCX, RAX); + emit_jump_target_offset(state, JUMP_LOC, state->offset); + } } else { /* Restore dividend to RCX */ emit_pop(state, RCX); - /* Store the dividend in RAX if the divisor was zero. */ /* Use conditional move to avoid a branch. */ - emit1(state, 0x48); - emit1(state, 0x0f); - emit1(state, 0x44); - emit1(state, 0xd1); /* cmove rdx, rcx */ + emit_conditional_move(state, RCX, RDX); + if (sign) { + /* handle REM overflow */ + emit1(state, 0x9d); /* popfq */ + uint32_t jump_loc = state->offset; + emit_jcc_offset(state, 0x85); + emit_cmp_imm32(state, RCX, 0x80000000); + emit_load_imm(state, RCX, 0); + emit_conditional_move(state, RCX, RDX); + emit_jump_target_offset(state, JUMP_LOC, state->offset); + } } } @@ -1183,10 +1216,10 @@ static void muldivmod(struct jit_state *state, emit_dataproc_3source(state, true, DP3_MADD, dst, dst, src, RZ); break; case 0x38: - divmod(state, JIT_OP_DIV_REG, dst, dst, src); + divmod(state, JIT_OP_DIV_REG, dst, dst, src, sign); break; case 0x98: - divmod(state, JIT_OP_MOD_REG, dst, dst, src); + divmod(state, JIT_OP_MOD_REG, dst, dst, src, sign); break; default: __UNREACHABLE; diff --git a/src/rv32_template.c b/src/rv32_template.c index 07811ce1..ee7920c6 100644 --- a/src/rv32_template.c +++ b/src/rv32_template.c @@ -1212,7 +1212,7 @@ RVOP( map, 
VR2, rd; mov, VR1, TMP; mov, VR0, VR2; - div, 0x38, TMP, VR2, 0; + div, 0x38, TMP, VR2, 1; /* FIXME: handle overflow */ })) @@ -1260,7 +1260,7 @@ GEN({ map, VR2, rd; mov, VR1, TMP; mov, VR0, VR2; - mod, 0x98, TMP, VR2, 0; + mod, 0x98, TMP, VR2, 1; /* FIXME: handle overflow */ }))