diff --git a/tests/unit/compiler/venom/test_algebraic_optimizer.py b/tests/unit/compiler/venom/test_algebraic_optimizer.py
index 9cf7c1d61b..770b827961 100644
--- a/tests/unit/compiler/venom/test_algebraic_optimizer.py
+++ b/tests/unit/compiler/venom/test_algebraic_optimizer.py
@@ -202,6 +202,28 @@ def test_interleaved_case(interleave_point):
     _check_pre_post(pre, post)
 
 
+# TODO: enable once range analysis is available
+# def test_fold_shifted_add_chain():
+#     pre = """
+#     main:
+#         %x = source
+#         %tmp0 = shl %x, 5
+#         %tmp1 = add 32, %tmp0
+#         %tmp2 = add 31, %tmp1
+#         %out = shr %tmp2, 5
+#         sink %out
+#     """
+
+#     post = """
+#     main:
+#         %x = source
+#         %out = add 1, %x
+#         sink %out
+#     """
+
+#     _check_pre_post(pre, post)
+
+
 def test_offsets():
     """
     Test of addition to offset rewrites
diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py
index 1fa747e926..465f03600d 100644
--- a/vyper/venom/analysis/dfg.py
+++ b/vyper/venom/analysis/dfg.py
@@ -20,6 +20,13 @@ def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
     def get_uses(self, op: IRVariable) -> OrderedSet[IRInstruction]:
         return self._dfg_inputs.get(op, OrderedSet())
 
+    def is_single_use(self, output: IRVariable) -> bool:
+        """
+        Return True iff `output` is used by exactly one instruction.
+        """
+        uses = self.get_uses(output)
+        return len(uses) == 1
+
     def get_uses_in_bb(self, op: IRVariable, bb: IRBasicBlock):
         """
         Get uses of a given variable in a specific basic block.
diff --git a/vyper/venom/passes/algebraic_optimization.py b/vyper/venom/passes/algebraic_optimization.py
index a4a7cd6be7..f719176c71 100644
--- a/vyper/venom/passes/algebraic_optimization.py
+++ b/vyper/venom/passes/algebraic_optimization.py
@@ -11,6 +11,7 @@
     flip_comparison_opcode,
 )
 from vyper.venom.passes.base_pass import InstUpdater, IRPass
+from vyper.venom.passes.sccp.eval import eval_arith
 
 TRUTHY_INSTRUCTIONS = ("iszero", "jnz", "assert", "assert_unreachable")
 
@@ -19,6 +20,23 @@ def lit_eq(op: IROperand, val: int) -> bool:
     return isinstance(op, IRLiteral) and wrap256(op.value) == wrap256(val)
 
 
+def lit_add(op: IROperand, val: int) -> int:
+    """Evaluate venom `add` on the literal `op` and the integer `val`."""
+    assert isinstance(op, IRLiteral)
+    return eval_arith("add", [op, IRLiteral(val)])
+
+
+def lit_sub(val: int, op: IROperand) -> int:
+    """
+    Evaluate venom `sub` on operands `[op, IRLiteral(val)]`.
+
+    NOTE(review): callers use this as `val - op` (subtracting the literal
+    from a running total) -- confirm against eval_arith's operand order.
+    """
+    assert isinstance(op, IRLiteral)
+    return eval_arith("sub", [op, IRLiteral(val)])
+
+
 class AlgebraicOptimizationPass(IRPass):
     """
     This pass reduces algebraic evaluatable expressions.
@@ -105,6 +123,171 @@ def _handle_offset(self):
     def _is_lit(self, operand: IROperand) -> bool:
         return isinstance(operand, IRLiteral)
 
+    def _extract_value_and_literal_operands(
+        self, inst: IRInstruction
+    ) -> tuple[IROperand | None, IRLiteral | None]:
+        """
+        Split the operands of `inst` into a non-literal and a literal one,
+        regardless of their positions.
+
+        Returns `(value_op, literal_op)`. Returns `(None, None)` if more than
+        one operand is a literal; either element is None if no operand of
+        that kind is present.
+        """
+        value_op = None
+        literal_op = None
+        for op in inst.operands:
+            if self._is_lit(op):
+                if literal_op is not None:
+                    return None, None
+                literal_op = op
+            else:
+                value_op = op
+        assert isinstance(literal_op, IRLiteral) or literal_op is None  # help mypy
+        return value_op, literal_op
+
+    def _fold_add_chain(self, inst: IRInstruction) -> bool:
+        """
+        Fold a chain of single-use `add`/`sub` instructions with literal
+        operands which ends in `inst` into one `add` (or a plain assign
+        if the literals cancel out).
+
+        Returns True iff `inst` was rewritten.
+        """
+        if inst.opcode not in {"add", "sub"}:
+            return False
+
+        op0, op1 = inst.operands
+        base_operand: IROperand | None = None
+        total = 0
+
+        if inst.opcode == "add":
+            base_operand, literal = self._extract_value_and_literal_operands(inst)
+            if literal is None or base_operand is None:
+                return False
+            total = lit_add(literal, total)
+        else:  # sub
+            # only handled when operands[0] is the literal and operands[1] is not
+            if self._is_lit(op0) and not self._is_lit(op1):
+                total = lit_sub(total, op0)
+                base_operand = op1
+            else:
+                return False
+
+        chain_base, traced = self._trace_add_chain(base_operand)
+        if chain_base is base_operand:
+            # nothing upstream was folded; leave `inst` untouched instead of
+            # rewriting a lone `sub` into an `add` of the negated constant
+            return False
+
+        # keep the constant within the 256-bit word so that cancelling
+        # literals are detected and the emitted literal stays in range
+        total = wrap256(total + traced)
+
+        if total == 0:
+            self.updater.mk_assign(inst, chain_base)
+            return True
+
+        self.updater.update(inst, "add", [chain_base, IRLiteral(total)])
+        return True
+
+    def _fold_shifted_add_chain(self, inst: IRInstruction, value_op: IROperand, shift: int) -> bool:
+        """
+        Rewrite `inst` to `add base, (c >> shift)` when `value_op` is
+        `(base << shift)` plus a chain of literal adds/subs totalling `c`.
+
+        NOTE: the only call site (in `_handle_inst_peephole`) is commented
+        out -- the rewrite is not safe without knowing literal ranges.
+
+        Returns True iff `inst` was rewritten.
+        """
+        if shift <= 0 or shift >= 256:
+            return False
+
+        traced = self._trace_shifted_add_chain(value_op, shift)
+        if traced is None:
+            return False
+
+        base_op, total = traced
+        add_const = total >> shift
+        new_ops = [base_op, IRLiteral(add_const)]
+        self.updater.update(inst, "add", new_ops)
+        return True
+
+    def _trace_add_chain(self, operand: IROperand) -> tuple[IROperand, int]:
+        """
+        Walk backwards from `operand` through single-use `add`/`sub`
+        instructions which have a literal operand, accumulating the literals.
+
+        Returns `(base, total)`: the first operand the walk could not look
+        through and the accumulated constant (0 if nothing was folded).
+        """
+        total = 0
+        current = operand
+
+        while isinstance(current, IRVariable):
+            producer = self.dfg.get_producing_instruction(current)
+            if producer is None:
+                break
+
+            if producer.opcode == "add":
+                assert producer.output is not None  # help mypy
+                if not self.dfg.is_single_use(producer.output):
+                    break
+
+                value_op, literal = self._extract_value_and_literal_operands(producer)
+                if literal is None or value_op is None:
+                    break
+
+                assert isinstance(literal, IRLiteral)  # help mypy
+                total = lit_add(literal, total)
+                current = value_op
+                continue
+
+            if producer.opcode == "sub":
+                assert producer.output is not None  # help mypy
+                if not self.dfg.is_single_use(producer.output):
+                    break
+                op0, op1 = producer.operands
+                if self._is_lit(op0) and not self._is_lit(op1):
+                    total = lit_sub(total, op0)
+                    current = op1
+                    continue
+                break
+
+            break
+
+        return current, total
+
+    def _trace_shifted_add_chain(
+        self, operand: IROperand, shift: int
+    ) -> tuple[IROperand, int] | None:
+        """
+        Match `operand` against `(x << shift)` plus a chain of literal
+        adds/subs.
+
+        Returns `(x, total)` where `total` is the accumulated constant, or
+        None if `operand` does not have that shape.
+        """
+        base_operand, total = self._trace_add_chain(operand)
+
+        if not isinstance(base_operand, IRVariable):
+            return None
+
+        producer = self.dfg.get_producing_instruction(base_operand)
+        if producer is None or producer.opcode != "shl":
+            return None
+
+        # operand positions matter for shifts: operands[0] is the shifted
+        # value and operands[1] the shift amount (as in the `x << 0` rule)
+        value_op, shl_shift = producer.operands
+        if self._is_lit(value_op) or not isinstance(shl_shift, IRLiteral):
+            return None
+        if shl_shift.value != shift:
+            return None
+
+        return value_op, total
+
     def _algebraic_opt(self):
         self._algebraic_opt_pass()
 
@@ -149,10 +332,22 @@ def _handle_inst_peephole(self, inst: IRInstruction):
         operands = inst.operands
 
         if inst.opcode in {"shl", "shr", "sar"}:
+            # operands[0] is the shifted value, operands[1] the shift amount
+            value_op, shift_op = operands
             # (x >> 0) == (x << 0) == x
-            if lit_eq(operands[1], 0):
-                self.updater.mk_assign(inst, operands[0])
+            if lit_eq(shift_op, 0):
+                self.updater.mk_assign(inst, value_op)
                 return
+            #
+            # Disabled for now -- we need to know literal ranges to do this safely
+            #
+            # if (
+            #     inst.opcode == "shr"
+            #     and isinstance(shift_op, IRLiteral)
+            #     and self._fold_shifted_add_chain(inst, value_op, shift_op.value)
+            # ):
+            #     return
+
             # no more cases for these instructions
             return
 
@@ -204,6 +399,10 @@ def _handle_inst_peephole(self, inst: IRInstruction):
                 self.updater.update(inst, "not", [operands[1]])
                 return
 
+            if inst.opcode in {"add", "sub"}:
+                if self._fold_add_chain(inst):
+                    return
+
             return
 
         # x & 0xFF..FF -> x