Skip to content

Commit

Permalink
feat: add optimizer guardrails, refactor (#2914)
Browse files Browse the repository at this point in the history
this commit factors out and generalizes the optimizer logic for
comparison operators. it clarifies the logic for dealing with boundary
cases, and cleans up the control flow in the optimizer in general. it
also adds an assembly peephole optimization to help optimize the
assembly generated by the comparator pass.

this commit also introduces a new IR keyword, "unique_symbol". it
functions as a guardrail to help ensure optimizer passes are sane. it is
a statement that codegen can insert into the IR, and sanity checks will
be performed to make sure that the statement is not optimized out, and
in some cases, to ensure that the statement shows up in the optimized
IR. it can be thought of as a very primitive effects tracking framework.

lastly, some tests are added to increase optimizer coverage.
  • Loading branch information
charles-cooper authored Jun 21, 2022
1 parent ddea185 commit 2fddbde
Show file tree
Hide file tree
Showing 10 changed files with 478 additions and 196 deletions.
94 changes: 84 additions & 10 deletions tests/compiler/ir/test_optimize_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
(["eq", 1, 2], [0]),
(["lt", 1, 2], [1]),
(["eq", "x", 0], ["iszero", "x"]),
(["ne", "x", 0], ["iszero", ["iszero", "x"]]),
(["ne", "x", 1], None),
(["iszero", ["ne", "x", 1]], ["iszero", ["iszero", ["iszero", ["xor", "x", 1]]]]),
(["eq", ["sload", 0], 0], ["iszero", ["sload", 0]]),
# branch pruner
(["if", ["eq", 1, 2], "pass"], ["seq"]),
Expand All @@ -16,31 +19,89 @@
(["seq", ["assert", ["lt", 1, 2]]], ["seq"]),
(["seq", ["assert", ["lt", 1, 2]], 2], [2]),
# condition rewriter
(["if", ["eq", "x", "y"], "pass"], ["if", ["iszero", ["sub", "x", "y"]], "pass"]),
(["if", ["eq", "x", "y"], "pass"], ["if", ["iszero", ["xor", "x", "y"]], "pass"]),
(["if", "cond", 1, 0], ["if", ["iszero", "cond"], 0, 1]),
(["assert", ["eq", "x", "y"]], ["assert", ["iszero", ["sub", "x", "y"]]]),
(["assert", ["eq", "x", "y"]], ["assert", ["iszero", ["xor", "x", "y"]]]),
# nesting
(["mstore", 0, ["eq", 1, 2]], ["mstore", 0, 0]),
# conditions
(["ge", "x", 0], [1]), # x >= 0 == True
(["ge", ["sload", 0], 0], None), # no-op
(["iszero", ["gt", "x", 2 ** 256 - 1]], [1]), # !(x > MAX_UINT256) == True
(["iszero", ["sgt", "x", 2 ** 255 - 1]], [1]), # !(signed x > MAX_INT256) == True
(["gt", "x", 2 ** 256 - 1], [0]), # x > MAX_UINT256 == False
# !(x > 0) => x == 0
(["iszero", ["gt", "x", 0]], ["iszero", ["iszero", ["iszero", "x"]]]),
# !(x < MAX_UINT256) => x == MAX_UINT256
(["iszero", ["lt", "x", 2 ** 256 - 1]], ["iszero", ["iszero", ["iszero", ["not", "x"]]]]),
# !(x < MAX_INT256) => x == MAX_INT256
(
["iszero", ["slt", "x", 2 ** 255 - 1]],
["iszero", ["iszero", ["iszero", ["xor", "x", 2 ** 255 - 1]]]],
),
# !(x > MIN_INT256) => x == MIN_INT256
(
["iszero", ["sgt", "x", -(2 ** 255)]],
["iszero", ["iszero", ["iszero", ["xor", "x", -(2 ** 255)]]]],
),
(["sgt", "x", 2 ** 255 - 1], [0]), # signed x > MAX_INT256 == False
(["sge", "x", 2 ** 255 - 1], ["eq", "x", 2 ** 255 - 1]),
(["eq", -1, "x"], ["iszero", ["not", "x"]]),
(["iszero", ["eq", -1, "x"]], ["iszero", ["iszero", ["not", "x"]]]),
(["le", "x", 0], ["iszero", "x"]),
(["le", 0, "x"], [1]),
(["le", 0, ["sload", 0]], None), # no-op
(["ge", "x", 0], [1]),
# boundary conditions
(["slt", "x", -(2 ** 255)], [0]),
(["sle", "x", -(2 ** 255)], ["eq", "x", -(2 ** 255)]),
(["lt", "x", 2 ** 256 - 1], None),
(["le", "x", 2 ** 256 - 1], [1]),
(["gt", "x", 0], ["iszero", ["iszero", "x"]]),
# x < 0 => false
(["lt", "x", 0], [0]),
# 0 < x => x != 0
(["lt", 0, "x"], ["iszero", ["iszero", "x"]]),
(["gt", 5, "x"], None),
(["ge", 5, "x"], None),
# x < 1 => x == 0
(["lt", "x", 1], ["iszero", "x"]),
(["slt", "x", 1], None),
(["gt", "x", 1], None),
(["sgt", "x", 1], None),
(["gt", "x", 2 ** 256 - 2], ["iszero", ["not", "x"]]),
(["lt", "x", 2 ** 256 - 2], None),
(["slt", "x", 2 ** 256 - 2], None),
(["sgt", "x", 2 ** 256 - 2], None),
(["slt", "x", -(2 ** 255) + 1], ["eq", "x", -(2 ** 255)]),
(["sgt", "x", -(2 ** 255) + 1], None),
(["lt", "x", -(2 ** 255) + 1], None),
(["gt", "x", -(2 ** 255) + 1], None),
(["sgt", "x", 2 ** 255 - 2], ["eq", "x", 2 ** 255 - 1]),
(["slt", "x", 2 ** 255 - 2], None),
(["gt", "x", 2 ** 255 - 2], None),
(["lt", "x", 2 ** 255 - 2], None),
# 5 > x; x < 5; x <= 4
(["iszero", ["gt", 5, "x"]], ["iszero", ["le", "x", 4]]),
(["iszero", ["ge", 5, "x"]], None),
# 5 >= x; x <= 5; x < 6
(["ge", 5, "x"], ["lt", "x", 6]),
(["lt", 5, "x"], None),
(["le", 5, "x"], None),
# 5 < x; x > 5; x >= 6
(["iszero", ["lt", 5, "x"]], ["iszero", ["ge", "x", 6]]),
(["iszero", ["le", 5, "x"]], None),
# 5 <= x; x >= 5; x > 4
(["le", 5, "x"], ["gt", "x", 4]),
(["sgt", 5, "x"], None),
(["sge", 5, "x"], None),
# 5 > x; x < 5; x <= 4
(["iszero", ["sgt", 5, "x"]], ["iszero", ["sle", "x", 4]]),
(["iszero", ["sge", 5, "x"]], None),
# 5 >= x; x <= 5; x < 6
(["sge", 5, "x"], ["slt", "x", 6]),
(["slt", 5, "x"], None),
(["sle", 5, "x"], None),
(["slt", "x", -(2 ** 255)], ["slt", "x", -(2 ** 255)]), # unimplemented
# tricky conditions
# 5 < x; x > 5; x >= 6
(["iszero", ["slt", 5, "x"]], ["iszero", ["sge", "x", 6]]),
(["iszero", ["sle", 5, "x"]], None),
# 5 <= x; x >= 5; x > 4
(["sle", 5, "x"], ["sgt", "x", 4]),
# tricky constant folds
(["sgt", 2 ** 256 - 1, 0], [0]), # -1 > 0
(["gt", 2 ** 256 - 1, 0], [1]), # -1 > 0
(["gt", 2 ** 255, 0], [1]), # 0x80 > 0
Expand All @@ -54,12 +115,21 @@
(["sgt", -(2 ** 255), 2 ** 255], [0]), # 0x80 > 0x80
(["slt", 2 ** 255, -(2 ** 255)], [0]), # 0x80 < 0x80
# arithmetic
(["ceil32", "x"], None),
(["ceil32", 0], [0]),
(["ceil32", 1], [32]),
(["ceil32", 32], [32]),
(["ceil32", 33], [64]),
(["ceil32", 95], [96]),
(["ceil32", 96], [96]),
(["ceil32", 97], [128]),
(["add", "x", 0], ["x"]),
(["add", 0, "x"], ["x"]),
(["sub", "x", 0], ["x"]),
(["sub", "x", "x"], [0]),
(["sub", ["sload", 0], ["sload", 0]], None),
(["sub", ["callvalue"], ["callvalue"]], None),
(["sub", -1, ["sload", 0]], ["not", ["sload", 0]]),
(["mul", "x", 1], ["x"]),
(["div", "x", 1], ["x"]),
(["sdiv", "x", 1], ["x"]),
Expand Down Expand Up @@ -90,6 +160,9 @@
(["exp", 1, "x"], [1]),
(["exp", 0, "x"], ["iszero", "x"]),
# bitwise ops
(["xor", "x", 2 ** 256 - 1], ["not", "x"]),
(["and", "x", 2 ** 256 - 1], ["x"]),
(["or", "x", 2 ** 256 - 1], [2 ** 256 - 1]),
(["shr", 0, "x"], ["x"]),
(["sar", 0, "x"], ["x"]),
(["shl", 0, "x"], ["x"]),
Expand Down Expand Up @@ -137,6 +210,7 @@ def test_ir_optimizer(ir):
else:
expected = IRnode.from_list(ir[1])
expected.repr_show_gas = True
optimized.annotation = None
assert optimized == expected


Expand Down
10 changes: 8 additions & 2 deletions vyper/builtin_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from vyper.codegen.core import (
STORE,
IRnode,
_freshname,
add_ofst,
bytes_data_ptr,
calculate_type_for_external_return,
Expand All @@ -24,6 +25,7 @@
clamp_nonzero,
copy_bytes,
ensure_in_memory,
eval_once_check,
eval_seq,
get_bytearray_length,
get_element_ptr,
Expand Down Expand Up @@ -1247,7 +1249,9 @@ class SelfDestruct(BuiltinFunction):
@process_inputs
def build_IR(self, expr, args, kwargs, context):
    """
    Build the IR for a `selfdestruct` call.

    `context.check_is_not_constant` is invoked first with the
    "selfdestruct" tag and the source expression. The emitted IR is a
    `seq` whose first element is a `unique_symbol` guard (see
    `eval_once_check`), followed by the `selfdestruct` operation on
    `args[0]`.

    The guard is a sanity check: if any later stage of the codegen
    pipeline duplicates this fragment, the duplicated symbol is detected
    rather than the side effect silently being emitted twice.
    """
    context.check_is_not_constant("selfdestruct", expr)
    # NOTE: there must be exactly one return here. An unguarded
    # `return IRnode.from_list(["selfdestruct", args[0]])` ahead of this
    # statement would make the eval-once guard unreachable.
    return IRnode.from_list(
        ["seq", eval_once_check(_freshname("selfdestruct")), ["selfdestruct", args[0]]]
    )


class BlockHash(BuiltinFunction):
Expand Down Expand Up @@ -1593,7 +1597,9 @@ def _create_ir(value, buf, length, salt=None, checked=True):
create_op = "create2"
args.append(salt)

ret = IRnode.from_list([create_op, *args])
ret = IRnode.from_list(
["seq", eval_once_check(_freshname("create_builtin")), [create_op, *args]]
)

if not checked:
return ret
Expand Down
18 changes: 17 additions & 1 deletion vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,24 @@ def LOAD(ptr: IRnode) -> IRnode:
return IRnode.from_list([op, ptr])


def eval_once_check(name):
    """
    Return a `unique_symbol` IR node carrying `name`.

    The node enforces uniqueness. Emit it next to a side-effecting
    operation to sanity check that the codegen pipeline generates that
    operation only once; otherwise IR-to-assembly will throw a duplicate
    label exception. There is no runtime overhead, since the jumpdest
    gets optimized out in the final stage of assembly.
    """
    guard = ["unique_symbol", name]
    return IRnode.from_list(guard)


def STORE(ptr: IRnode, val: IRnode) -> IRnode:
    """
    Build the IR that stores `val` at the location `ptr` points to.

    The store opcode is taken from `ptr.location.store_op`. The result is
    a `seq` of a `unique_symbol` guard (see `eval_once_check`) followed by
    the store itself, so that an accidental duplication of the store by a
    later pass is detected instead of silently emitting it twice.

    Raises:
        CompilerPanic: if `ptr` has no location (it is not a pointer), or
            if its location defines no store opcode.
    """
    if ptr.location is None:
        raise CompilerPanic("cannot dereference non-pointer type")
    op = ptr.location.store_op
    if op is None:
        raise CompilerPanic(f"unreachable {ptr.location}")  # pragma: notest

    # NOTE: the guarded form below must be the only return. An unguarded
    # `return IRnode.from_list([op, ptr, val])` placed before it would
    # leave the eval-once check as dead code.
    _check = _freshname(f"{op}_")
    return IRnode.from_list(["seq", eval_once_check(_check), [op, ptr, val]])


# Unwrap location
Expand Down Expand Up @@ -707,6 +718,11 @@ def _freshname(name):
return f"{name}{_label}"


def reset_names():
    """
    Reset the module-level `_label` counter to 0.

    `_freshname` appends `_label` to the names it generates, so resetting
    the counter makes generated names start over from the same point.
    """
    global _label
    _label = 0


# returns True if t is ABI encoded and is a type that needs any kind of
# validation
def needs_clamp(t, encoding):
Expand Down
7 changes: 7 additions & 0 deletions vyper/codegen/external_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from vyper.address_space import MEMORY
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.core import (
_freshname,
calculate_type_for_external_return,
check_assign,
check_external_call,
dummy_node_for_type,
eval_once_check,
make_setter,
needs_clamp,
unwrap_location,
Expand Down Expand Up @@ -178,6 +180,11 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con

ret = ["seq"]

# this is a sanity check to prevent double evaluation of the external call
# in the codegen pipeline. if the external call gets doubly evaluated,
# a duplicate label exception will get thrown during assembly.
ret.append(eval_once_check(_freshname(call_expr.node_source_code)))

buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, args_ir, context)

ret_unpacker, ret_ofst, ret_len = _unpack_returndata(
Expand Down
27 changes: 27 additions & 0 deletions vyper/codegen/ir_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ def _check(condition, err):
raise CodegenPanic(f"2nd argument to label must be var_list, {self}")
self.valency = 0
self._gas = 1 + sum(t.gas for t in self.args)
elif self.value == "unique_symbol":
# a label which enforces uniqueness, and does not appear
# in generated bytecode. this is useful for generating
# internal assertions that a particular IR fragment only
# occurs a single time in a program. note that unique_symbol
# must be distinct from all `unique_symbol`s AS WELL AS all
# `label`s, otherwise IR-to-assembly will raise an exception.
self.valency = 0
self._gas = 0

# var_list names a variable number stack variables
elif self.value == "var_list":
for arg in self.args:
Expand All @@ -291,6 +301,7 @@ def _check(condition, err):
self._gas = sum([arg.gas for arg in self.args])
elif self.value == "deploy":
self.valency = 0
_check(len(self.args) == 3, f"`deploy` should have three args {self}")
self._gas = NullAttractor() # unknown
# Stack variables
else:
Expand Down Expand Up @@ -324,6 +335,22 @@ def is_complex_ir(self):
and self.value.lower() not in do_not_cache
)

# unused, but might be useful for analysis at some point
def unique_symbols(self):
    """
    Collect the names of every `unique_symbol` found in this subtree.

    Returns the set of symbol names (the `value` of the first argument of
    each `unique_symbol` node). For a `deploy` node only the first and
    third arguments are traversed; for every other node all arguments are.

    Raises AssertionError if the same symbol name is found more than once.
    """
    seen = set()
    if self.value == "unique_symbol":
        seen.add(self.args[0].value)

    if self.value == "deploy":
        subnodes = [self.args[0], self.args[2]]
    else:
        subnodes = self.args

    for node in subnodes:
        found = node.unique_symbols()
        clashes = seen & found
        assert not clashes, f"non-unique symbols {clashes}"
        seen |= found
    return seen

@property
def is_literal(self):
return isinstance(self.value, int) or self.value == "multi"
Expand Down
6 changes: 4 additions & 2 deletions vyper/codegen/self_call.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from vyper.address_space import MEMORY
from vyper.codegen.core import make_setter
from vyper.codegen.core import _freshname, eval_once_check, make_setter
from vyper.codegen.ir_node import IRnode, push_label_to_stack
from vyper.codegen.types import TupleType
from vyper.exceptions import StateAccessViolation, StructureException
Expand Down Expand Up @@ -91,7 +91,9 @@ def ir_for_self_call(stmt_expr, context):
# pass return label to subroutine
goto_op += [push_label_to_stack(return_label)]

call_sequence = ["seq", copy_args, goto_op, ["label", return_label, ["var_list"], "pass"]]
call_sequence = ["seq"]
call_sequence.append(eval_once_check(_freshname(stmt_expr.node_source_code)))
call_sequence.extend([copy_args, goto_op, ["label", return_label, ["var_list"], "pass"]])
if return_buffer is not None:
# push return buffer location to stack
call_sequence += [return_buffer]
Expand Down
6 changes: 5 additions & 1 deletion vyper/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Sequence, Union

from vyper.compiler import output
import vyper.ast as vy_ast # break an import cycle
import vyper.codegen.core as codegen
import vyper.compiler.output as output
from vyper.compiler.phases import CompilerData
from vyper.evm.opcodes import DEFAULT_EVM_VERSION, evm_wrapper
from vyper.typing import (
Expand Down Expand Up @@ -112,6 +114,8 @@ def compile_codes(
):
interfaces = interfaces[contract_name]

# make IR output the same between runs
codegen.reset_names()
compiler_data = CompilerData(
source_code,
contract_name,
Expand Down
2 changes: 1 addition & 1 deletion vyper/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _build_asm(asm_list):
for node in asm_list:

if isinstance(node, list):
output_string += "[ " + _build_asm(node) + "] "
output_string += "{ " + _build_asm(node) + "} "
continue

if in_push > 0:
Expand Down
Loading

0 comments on commit 2fddbde

Please sign in to comment.