Skip to content
22 changes: 22 additions & 0 deletions tests/unit/compiler/venom/test_algebraic_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,28 @@ def test_interleaved_case(interleave_point):
_check_pre_post(pre, post)


# TODO: enable when range analysis is available
# def test_fold_shifted_add_chain():
# pre = """
# main:
# %x = source
# %tmp0 = shl %x, 5
# %tmp1 = add 32, %tmp0
# %tmp2 = add 31, %tmp1
# %out = shr %tmp2, 5
# sink %out
# """

# post = """
# main:
# %x = source
# %out = add 1, %x
# sink %out
# """

# _check_pre_post(pre, post)


def test_offsets():
"""
Test of addition to offset rewrites
Expand Down
3 changes: 3 additions & 0 deletions vyper/venom/analysis/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
def get_uses(self, op: IRVariable) -> OrderedSet[IRInstruction]:
    """
    Look up `op` in `self._dfg_inputs` and return the stored entry.

    When `op` has no entry, a freshly created empty OrderedSet is
    returned instead (it is not stored back into the mapping).
    """
    no_uses = OrderedSet()
    return self._dfg_inputs.get(op, no_uses)

def is_single_use(self, output: IRVariable) -> bool:
    """
    Tell whether `output` has exactly one use.

    The answer is derived from the size of the collection returned by
    `get_uses(output)`; zero uses or several uses both yield False.
    """
    use_count = len(self.get_uses(output))
    return use_count == 1

def get_uses_in_bb(self, op: IRVariable, bb: IRBasicBlock):
"""
Get uses of a given variable in a specific basic block.
Expand Down
145 changes: 143 additions & 2 deletions vyper/venom/passes/algebraic_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
flip_comparison_opcode,
)
from vyper.venom.passes.base_pass import InstUpdater, IRPass
from vyper.venom.passes.sccp.eval import eval_arith

TRUTHY_INSTRUCTIONS = ("iszero", "jnz", "assert", "assert_unreachable")

Expand All @@ -19,6 +20,16 @@ def lit_eq(op: IROperand, val: int) -> bool:
return isinstance(op, IRLiteral) and wrap256(op.value) == wrap256(val)


def lit_add(op: IROperand, val: int) -> int:
    """
    Evaluate an "add" of the literal operand `op` and the plain integer
    `val` by delegating to `eval_arith`.

    `op` must be an IRLiteral; this is asserted.
    """
    assert isinstance(op, IRLiteral)
    addend = IRLiteral(val)
    return eval_arith("add", [op, addend])


def lit_sub(val: int, op: IROperand) -> int:
    """
    Evaluate a "sub" involving the plain integer `val` and the literal
    operand `op` by delegating to `eval_arith`.

    `op` must be an IRLiteral; this is asserted. The operand list handed
    to `eval_arith` is `[op, IRLiteral(val)]`.
    NOTE(review): callers in this file rely on the result being
    `val - op` -- confirm against eval_arith's operand ordering.
    """
    assert isinstance(op, IRLiteral)
    wrapped_val = IRLiteral(val)
    return eval_arith("sub", [op, wrapped_val])


class AlgebraicOptimizationPass(IRPass):
"""
This pass reduces algebraic evaluatable expressions.
Expand Down Expand Up @@ -105,6 +116,121 @@ def _handle_offset(self):
def _is_lit(self, operand: IROperand) -> bool:
    """
    Return True when `operand` is an IRLiteral instance, False otherwise.
    """
    operand_is_literal = isinstance(operand, IRLiteral)
    return operand_is_literal

def _extract_value_and_literal_operands(
    self, inst: IRInstruction
) -> tuple[IROperand | None, IRLiteral | None]:
    """
    Split the operands of `inst` into (non-literal operand, literal operand).

    The split ignores operand position. Outcomes:
      * exactly one literal: (last non-literal seen or None, that literal)
      * no literal at all:   (last non-literal seen or None, None)
      * two or more literals: (None, None)
    """
    non_literal = None
    found_literal = None

    for operand in inst.operands:
        if not self._is_lit(operand):
            # a later non-literal overwrites an earlier one
            non_literal = operand
            continue

        if found_literal is not None:
            # second literal: the instruction does not have the wanted shape
            return None, None
        found_literal = operand

    assert found_literal is None or isinstance(found_literal, IRLiteral)  # help mypy
    return non_literal, found_literal

def _fold_add_chain(self, inst: IRInstruction) -> bool:
    """
    Collapse a chain of literal additions/subtractions ending at `inst`
    into a single instruction.

    Supported shapes for `inst`:
      * `add` with exactly one literal and one non-literal operand
      * `sub` whose operands[0] is a literal and operands[1] is not

    The non-literal operand is followed backwards through
    `_trace_add_chain`, and the literals found on the way are summed
    modulo 2**256. `inst` is then rewritten through `self.updater`:
    as an assignment of the chain's base operand when the literals
    cancel out, otherwise as `add base, <sum>`.

    Returns True when `inst` had a supported shape (it is rewritten even
    if no earlier chain link was found), False when it was left alone.
    """
    if inst.opcode not in {"add", "sub"}:
        return False

    base_operand: IROperand | None = None
    total = 0

    if inst.opcode == "add":
        base_operand, literal = self._extract_value_and_literal_operands(inst)
        if literal is None or base_operand is None:
            return False
        total = lit_add(literal, total)
    else:  # sub
        op0, op1 = inst.operands
        if not self._is_lit(op0) or self._is_lit(op1):
            return False
        total = lit_sub(total, op0)
        base_operand = op1

    base_operand, traced = self._trace_add_chain(base_operand)
    # `total` and `traced` are each already reduced mod 2**256, but their
    # plain sum can reach 2**256 or more. Wrap again so that cancelling
    # literals (e.g. -1 followed by +1) are recognised as zero and the
    # emitted literal stays within the 256-bit range.
    total = wrap256(total + traced)

    if total == 0:
        self.updater.mk_assign(inst, base_operand)
        return True

    self.updater.update(inst, "add", [base_operand, IRLiteral(total)])
    return True

def _fold_shifted_add_chain(self, inst: IRInstruction, value_op: IROperand, shift: int) -> bool:
    """
    Try to rewrite `inst` (a right shift of `value_op` by `shift`) as a
    plain `add` when `value_op` is a chain of literal additions on top
    of a left shift by the same amount.

    On success `inst` becomes `add base, (sum_of_literals >> shift)` and
    True is returned. False is returned, with `inst` untouched, when
    `shift` is outside 1..255 or no matching chain is found.

    NOTE: the only call site visible in this file is commented out with
    the remark that literal ranges must be known to do this safely.
    """
    if not 0 < shift < 256:
        return False

    chain = self._trace_shifted_add_chain(value_op, shift)
    if chain is None:
        return False

    shifted_base, literal_sum = chain
    replacement_ops = [shifted_base, IRLiteral(literal_sum >> shift)]
    self.updater.update(inst, "add", replacement_ops)
    return True

def _trace_add_chain(self, operand: IROperand) -> tuple[IROperand, int]:
    """
    Walk backwards from `operand` through producing `add`/`sub`
    instructions that combine one value with one literal.

    Returns `(base, total)`: `base` is the first operand at which the
    walk stopped and `total` is the literal offset accumulated through
    `lit_add` / `lit_sub` along the way (0 if nothing was followed).

    The walk stops when the current operand is not an IRVariable, has
    no producing instruction, is produced by something other than
    `add`/`sub`, has a producer output that is not single-use, or has a
    producer that does not have the value-plus-literal shape.
    """
    acc = 0
    node = operand

    while isinstance(node, IRVariable):
        source_inst = self.dfg.get_producing_instruction(node)
        if source_inst is None or source_inst.opcode not in ("add", "sub"):
            break

        assert source_inst.output is not None  # help mypy
        if not self.dfg.is_single_use(source_inst.output):
            # the intermediate result is needed elsewhere; do not look through it
            break

        if source_inst.opcode == "add":
            next_node, lit = self._extract_value_and_literal_operands(source_inst)
            if lit is None or next_node is None:
                break
            assert isinstance(lit, IRLiteral)  # help mypy
            acc = lit_add(lit, acc)
        else:  # sub
            lhs, rhs = source_inst.operands
            # only the shape "literal in operands[0], value in operands[1]" is followed
            if not self._is_lit(lhs) or self._is_lit(rhs):
                break
            acc = lit_sub(acc, lhs)
            next_node = rhs

        node = next_node

    return node, acc

def _trace_shifted_add_chain(
    self, operand: IROperand, shift: int
) -> tuple[IROperand, int] | None:
    """
    Match `operand` against "literal add/sub chain on top of
    `shl value, <shift>`".

    The chain is followed with `_trace_add_chain`; its base must be a
    variable produced by a `shl` whose shift amount is a literal equal
    to `shift` and whose shifted value is not a literal.

    Returns `(value, total)` -- the operand that was shifted and the
    accumulated literal offset -- or None when the pattern does not match.
    """
    base_operand, total = self._trace_add_chain(operand)

    if not isinstance(base_operand, IRVariable):
        return None

    producer = self.dfg.get_producing_instruction(base_operand)
    if producer is None or producer.opcode != "shl":
        return None

    # `shl` is not commutative, so the operands must be read by position
    # rather than by "whichever one happens to be a literal"; otherwise
    # `literal << %var` would be mistaken for `%var << literal`.
    # NOTE(review): assumes shift instructions store their operands as
    # (value, shift_amount) -- confirm against the IR operand convention.
    value_op, shl_shift = producer.operands
    if not self._is_lit(shl_shift) or self._is_lit(value_op):
        return None
    if shl_shift.value != shift:
        return None

    return value_op, total

def _algebraic_opt(self):
self._algebraic_opt_pass()

Expand Down Expand Up @@ -149,10 +275,21 @@ def _handle_inst_peephole(self, inst: IRInstruction):
operands = inst.operands

if inst.opcode in {"shl", "shr", "sar"}:
value_op, shift_lit = self._extract_value_and_literal_operands(inst)
if shift_lit is None or value_op is None:
return
# (x >> 0) == (x << 0) == x
if lit_eq(operands[1], 0):
self.updater.mk_assign(inst, operands[0])
if lit_eq(shift_lit, 0):
self.updater.mk_assign(inst, value_op)
return
#
# Disabled for now -- we need to know literal ranges to do this safely
#
# if inst.opcode == "shr" and self._fold_shifted_add_chain(
# inst, value_op, shift_lit.value
# ):
# return

# no more cases for these instructions
return

Expand Down Expand Up @@ -204,6 +341,10 @@ def _handle_inst_peephole(self, inst: IRInstruction):
self.updater.update(inst, "not", [operands[1]])
return

if inst.opcode in {"add", "sub"}:
if self._fold_add_chain(inst):
return

return

# x & 0xFF..FF -> x
Expand Down
Loading