diff --git a/boa/coverage.py b/boa/coverage.py index 641bc17c..4257482a 100644 --- a/boa/coverage.py +++ b/boa/coverage.py @@ -2,9 +2,8 @@ import coverage.plugin import vyper.ast as vy_ast -from vyper.ir import compile_ir +from vyper.ast.parse import parse_to_ast -import boa.interpret from boa.contracts.vyper.ast_utils import get_fn_ancestor_from_node from boa.environment import Env @@ -40,7 +39,7 @@ def __init__(self): # coverage.py requires us to inspect the python call frame to # see what line number to produce. we hook into specially crafted - # Env._hook_trace_pc which is called for every pc if coverage is + # Env._trace_cov which is called for every unique pc if coverage is # enabled, and then back out the contract and lineno information # from there. @@ -48,47 +47,31 @@ def _valid_frame(self, frame): if hasattr(frame.f_code, "co_qualname"): # Python>=3.11 code_qualname = frame.f_code.co_qualname - return code_qualname == Env._hook_trace_computation.__qualname__ + return code_qualname == Env._trace_cov.__qualname__ else: # in Python<3.11 we don't have co_qualname, so try hard to # find a match anyways. (this might fail if for some reason - # the executing env has a monkey-patched _hook_trace_computation + # the executing env has a monkey-patched _trace_cov # or something) env = Env.get_singleton() - return frame.f_code == env._hook_trace_computation.__code__ - - def _contract_for_frame(self, frame): - if not self._valid_frame(frame): - return None - return frame.f_locals["contract"] + return frame.f_code == env._trace_cov.__code__ def dynamic_source_filename(self, filename, frame): - contract = self._contract_for_frame(frame) - if contract is None or contract.filename is None: + if not self._valid_frame(frame): return None - - return str(contract.filename) + return frame.f_locals["filename"] def has_dynamic_source_filename(self): return True # https://coverage.rtfd.io/en/stable/api_plugin.html#coverage.FileTracer.line_number_range def line_number_range(self, frame): - contract = self._contract_for_frame(frame) - if contract is None: - return (-1, -1) - - if (pc := frame.f_locals.get("_pc")) is None: + if not self._valid_frame(frame): return (-1, -1) - pc_map = contract.source_map["pc_raw_ast_map"] - - node = pc_map.get(pc) - if node is None: - return (-1, -1) + start_lineno = frame.f_locals["lineno"] - start_lineno = node.lineno # end_lineno = node.end_lineno # note: `return start_lineno, end_lineno` doesn't seem to work. return start_lineno, start_lineno @@ -99,7 +82,7 @@ def dynamic_context(self, frame): # helper function. null returns get optimized directly into a jump -# to function cleanup which maps to the parnet FunctionDef ast. +# to function cleanup which maps to the parent FunctionDef ast. def _is_null_return(ast_node): match ast_node: case vy_ast.Return(value=None): @@ -112,66 +95,68 @@ def __init__(self, filename, env=None): super().__init__(filename) @cached_property - def _compiler_data(self): - return boa.interpret.compiler_data(self.source(), self.filename) + def _ast(self): + return parse_to_ast(self.source()) def arcs(self): ret = set() - for ast_node in self._compiler_data.vyper_module.get_descendants(): - if isinstance(ast_node, vy_ast.If): - fn_node = get_fn_ancestor_from_node(ast_node) - - # one arc is directly into the body - arc_true = ast_node.body[0].lineno - if _is_null_return(ast_node.body[0]): - arc_true = fn_node.lineno - ret.add((ast_node.lineno, arc_true)) - - # the other arc is to the end of the if statement - # try hard to find the next executable line. - children = ast_node._parent.get_children() - for node, next_ in zip(children, children[1:]): - if id(node) == id(ast_node): - arc_false = next_.lineno - break - else: - # the if stmt was the last stmt in the enclosing scope. - arc_false = ast_node._parent.end_lineno + 1 - - # unless there is an else or elif. then the other - # arc is to the else/elif statement. - if ast_node.orelse: - arc_false = ast_node.orelse[0].lineno - - # return cases: - # if it's past the end of the fn it's an implicit return - if arc_false > fn_node.end_lineno: - arc_false = fn_node.lineno - # or it's an explicit return - if ast_node.orelse and _is_null_return(ast_node.orelse[0]): - arc_false = fn_node.lineno - - ret.add((ast_node.lineno, arc_false)) + for ast_node in self._ast.get_descendants(vy_ast.If): + fn_node = get_fn_ancestor_from_node(ast_node) + + # one arc is directly into the body + arc_true = ast_node.body[0].lineno + if _is_null_return(ast_node.body[0]): + arc_true = fn_node.lineno + ret.add((ast_node.lineno, arc_true)) + + # the other arc is to the end of the if statement + # try hard to find the next executable line. + children = ast_node._parent.get_children() + for node, next_ in zip(children, children[1:]): + if id(node) == id(ast_node): + arc_false = next_.lineno + break + else: + # the if stmt was the last stmt in the enclosing scope. + arc_false = ast_node._parent.end_lineno + 1 + + # unless there is an else or elif. then the other + # arc is to the else/elif statement. + if ast_node.orelse: + arc_false = ast_node.orelse[0].lineno + + # return cases: + # if it's past the end of the fn it's an implicit return + if arc_false > fn_node.end_lineno: + arc_false = fn_node.lineno + # or it's an explicit return + if ast_node.orelse and _is_null_return(ast_node.orelse[0]): + arc_false = fn_node.lineno + + ret.add((ast_node.lineno, arc_false)) return ret def exit_counts(self): ret = {} - for ast_node in self._compiler_data.vyper_module.get_descendants(vy_ast.If): + for ast_node in self._ast: ret[ast_node.lineno] = 2 return ret @cached_property def _lines(self): ret = set() - c = self._compiler_data - # source_map should really be in CompilerData - _, source_map = compile_ir.assembly_to_evm(c.assembly_runtime) + functions = self._ast.get_children(vy_ast.FunctionDef) - for node in source_map["pc_raw_ast_map"].values(): - ret.add(node.lineno) + for f in functions: + # add entry to the function for external functions? + # ret.add(f.lineno) + for stmt in f.body: + ret.add(stmt.lineno) + for node in stmt.get_descendants(): + ret.add(node.lineno) return ret diff --git a/boa/environment.py b/boa/environment.py index 526b3c9a..3e4d8c8d 100644 --- a/boa/environment.py +++ b/boa/environment.py @@ -295,24 +295,39 @@ def execute_code( contract=contract, ) if self._coverage_enabled: - self._hook_trace_computation(ret, contract) + self._trace_computation(ret, contract) if ret._gas_meter_class != NoGasMeter: self._update_gas_used(ret.get_gas_used()) return ret - def _hook_trace_computation(self, computation, contract=None): - # XXX perf: don't trace if contract is None - for _pc in computation.code._trace: - # loop over pc so that it is available when coverage hooks into it - pass + # trace pcs for coverage sake. dummy function which + # just issues the right calls to _trace_cov() to get picked + # up by coverage. bit ugly, but tracer only allows + # dynamic_source_filename to be set once per (python) function call, + # so we need to use this in case the pc trace covers multiple files + def _trace_computation(self, computation, contract=None): + # perf: don't trace if contract is None + if contract is not None: + ast_map = contract.source_map["pc_raw_ast_map"] + seen_pcs = set() + for pc in computation.code._trace: + if pc in seen_pcs: + continue + if (node := ast_map.get(pc)) is not None: + mod = node.module_node + self._trace_cov(mod.path, node.lineno) + seen_pcs.add(pc) for child in computation.children: if child.msg.code_address == b"": continue child_contract = self._lookup_contract_fast(child.msg.code_address) - self._hook_trace_computation(child, child_contract) + self._trace_computation(child, child_contract) + + def _trace_cov(self, filename, lineno): + pass def get_code(self, address: _AddressType) -> bytes: return self.evm.get_code(Address(address)) diff --git a/boa/interpret.py b/boa/interpret.py index 4cf70e9c..7ab96a90 100644 --- a/boa/interpret.py +++ b/boa/interpret.py @@ -159,7 +159,7 @@ def get_compiler_data(): _ = ret.bytecode, ret.bytecode_runtime return ret - assert isinstance(deployer, type) + assert isinstance(deployer, type) or deployer is None deployer_id = repr(deployer) # a unique str identifying the deployer class cache_key = str((contract_name, fingerprint, kwargs, deployer_id)) return _disk_cache.caching_lookup(cache_key, get_compiler_data)