From 2944b2ea085e2fa8ba62a0edd1276a97108606ee Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 8 Sep 2024 07:04:28 -0400 Subject: [PATCH 01/14] wip: pull import resolution into separate pass --- vyper/semantics/analysis/import_graph.py | 37 --- vyper/semantics/analysis/imports.py | 290 +++++++++++++++++++++++ vyper/semantics/analysis/module.py | 269 +++------------------ 3 files changed, 324 insertions(+), 272 deletions(-) delete mode 100644 vyper/semantics/analysis/import_graph.py create mode 100644 vyper/semantics/analysis/imports.py diff --git a/vyper/semantics/analysis/import_graph.py b/vyper/semantics/analysis/import_graph.py deleted file mode 100644 index e406878194..0000000000 --- a/vyper/semantics/analysis/import_graph.py +++ /dev/null @@ -1,37 +0,0 @@ -import contextlib -from dataclasses import dataclass, field -from typing import Iterator - -from vyper import ast as vy_ast -from vyper.exceptions import CompilerPanic, ImportCycle - -""" -data structure for collecting import statements and validating the -import graph -""" - - -@dataclass -class ImportGraph: - # the current path in the import graph traversal - _path: list[vy_ast.Module] = field(default_factory=list) - - def push_path(self, module_ast: vy_ast.Module) -> None: - if module_ast in self._path: - cycle = self._path + [module_ast] - raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) - - self._path.append(module_ast) - - def pop_path(self, expected: vy_ast.Module) -> None: - popped = self._path.pop() - if expected != popped: - raise CompilerPanic("unreachable") - - @contextlib.contextmanager - def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: - self.push_path(module_ast) - try: - yield - finally: - self.pop_path(module_ast) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py new file mode 100644 index 0000000000..9fdddf8266 --- /dev/null +++ b/vyper/semantics/analysis/imports.py @@ -0,0 +1,290 @@ +import contextlib +import os +from dataclasses import dataclass, field +from pathlib import Path, PurePath +from typing import Any, Iterator + +import vyper +from vyper import ast as vy_ast +from vyper.compiler.input_bundle import ( + ABIInput, + CompilerInput, + FileInput, + FilesystemInputBundle, + InputBundle, + PathLike, +) +from vyper.exceptions import ( + CompilerPanic, + DuplicateImport, + ImportCycle, + ModuleNotFound, + StructureException, +) + +""" +collect import statements and validate the import graph +""" + + +@dataclass +class _ImportGraph: + # the current path in the import graph traversal + _path: list[vy_ast.Module] = field(default_factory=list) + + def push_path(self, module_ast: vy_ast.Module) -> None: + if module_ast in self._path: + cycle = self._path + [module_ast] + raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) + + self._path.append(module_ast) + + def pop_path(self, expected: vy_ast.Module) -> None: + popped = self._path.pop() + if expected != popped: + raise CompilerPanic("unreachable") + + @contextlib.contextmanager + def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: + self.push_path(module_ast) + try: + yield + finally: + self.pop_path(module_ast) + + +class ImportAnalyzer: + def __init__(self, input_bundle: InputBundle, graph: _ImportGraph): + self.graph = graph + self._ast_of: dict[PathLike, vy_ast.Module] = {} + + self.integrity_sum = None + + def resolve_imports(self, module_ast: vy_ast.Module): + self._resolve_imports_r(module_ast) + self.integrity_sum = self._calculate_integrity_sum(module_ast) + + def _resolve_imports_r(self, module_ast: vy_ast.Module): + with self.graph.enter_path(module_ast): + for node in module_ast.body: + if isinstance(node, vy_ast.Import): + self._handle_Import(node) + elif isinstance(node, vy_ast.ImportFrom): + self._handle_ImportFrom(node) + + def _handle_Import(self, node: vy_ast.Import): + # import x.y[name] as y[alias] + + alias = node.alias + + if alias is None: + alias = node.name + + # don't handle things like `import x.y` + if "." in alias: + msg = "import requires an accompanying `as` statement" + suggested_alias = node.name[node.name.rfind(".") :] + hint = f"try `import {node.name} as {suggested_alias}`" + raise StructureException(msg, node, hint=hint) + + self._add_import(node, 0, node.name, alias) + + def _handle_ImportFrom(self, node: vy_ast.ImportFrom): + # from m.n[module] import x[name] as y[alias] + alias = node.alias or node.name + + module = node.module or "" + if module: + module += "." + + qualified_module_name = module + node.name + self._add_import(node, node.level, qualified_module_name, alias) + + def _add_import( + self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str + ) -> None: + compiler_input, ast = self._load_import(node, level, qualified_module_name, alias) + node._metadata["compiler_input"] = compiler_input + node._metadata["imported_ast"] = ast + node._metadata["alias"] = alias + + # load an InterfaceT or ModuleInfo from an import. + # raises FileNotFoundError + def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: + # the directory this (currently being analyzed) module is in + self_search_path = Path(self.ast.resolved_path).parent + + with self.input_bundle.poke_search_path(self_search_path): + return self._load_import_helper(node, level, module_str, alias) + + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: + # cache ast if we have seen it before. + # this gives us the additional property of object equality on + # two ASTs produced from the same source + ast_of = self._ast_of + if file.source_id not in ast_of: + ast_of[file.source_id] = _parse_ast(file) + + return ast_of[file.source_id] + + def _load_import_helper( + self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str + ) -> tuple[CompilerInput, Any]: + if _is_builtin(module_str): + return _load_builtin_import(level, module_str) + + path = _import_to_path(level, module_str) + + if path in self.graph._imported_modules: + previous_import_stmt = self._imported_modules[path] + raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) + + self._imported_modules[path] = node + + err = None + + try: + path_vy = path.with_suffix(".vy") + file = self.input_bundle.load_file(path_vy) + assert isinstance(file, FileInput) # mypy hint + + module_ast = self._ast_from_file(file) + self.resolve_imports(module_ast) + + return file, module_ast + + except FileNotFoundError as e: + # escape `e` from the block scope, it can make things + # easier to debug. + err = e + + try: + file = self.input_bundle.load_file(path.with_suffix(".vyi")) + assert isinstance(file, FileInput) # mypy hint + module_ast = self._ast_from_file(file) + + # no recursion yet + # self.resolve_imports(module_ast) + + return file, module_ast + + except FileNotFoundError: + pass + + try: + file = self.input_bundle.load_file(path.with_suffix(".json")) + assert isinstance(file, ABIInput) # mypy hint + return file, file.abi + except FileNotFoundError: + pass + + hint = None + if module_str.startswith("vyper.interfaces"): + hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" + + # copy search_paths, makes debugging a bit easier + search_paths = self.input_bundle.search_paths.copy() # noqa: F841 + raise ModuleNotFound(module_str, hint=hint) from err + + +def _parse_ast(file: FileInput) -> vy_ast.Module: + module_path = file.resolved_path # for error messages + try: + # try to get a relative path, to simplify the error message + cwd = Path(".") + if module_path.is_absolute(): + cwd = cwd.resolve() + module_path = module_path.relative_to(cwd) + except ValueError: + # we couldn't get a relative path (cf. docs for Path.relative_to), + # use the resolved path given to us by the InputBundle + pass + + ret = vy_ast.parse_to_ast( + file.source_code, + source_id=file.source_id, + module_path=module_path.as_posix(), + resolved_path=file.resolved_path.as_posix(), + ) + return ret + + +# convert an import to a path (without suffix) +def _import_to_path(level: int, module_str: str) -> PurePath: + base_path = "" + if level > 1: + base_path = "../" * (level - 1) + elif level == 1: + base_path = "./" + return PurePath(f"{base_path}{module_str.replace('.','/')}/") + + +# can add more, e.g. "vyper.builtins.interfaces", etc. +BUILTIN_PREFIXES = ["ethereum.ercs"] + + +# TODO: could move this to analysis/common.py or something +def _is_builtin(module_str): + return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES) + + +_builtins_cache: dict[PathLike, tuple[CompilerInput, vy_ast.Module]] = {} + + +def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, vy_ast.Module]: + if not _is_builtin(module_str): # pragma: nocover + raise CompilerPanic("unreachable!") + + builtins_path = vyper.builtins.interfaces.__path__[0] + # hygiene: convert to relpath to avoid leaking user directory info + # (note Path.relative_to cannot handle absolute to relative path + # conversion, so we must use the `os` module). + builtins_path = os.path.relpath(builtins_path) + + search_path = Path(builtins_path).parent.parent.parent + # generate an input bundle just because it knows how to build paths. + input_bundle = FilesystemInputBundle([search_path]) + + # remap builtins directory -- + # ethereum/ercs => vyper/builtins/interfaces + remapped_module = module_str + if remapped_module.startswith("ethereum.ercs"): + remapped_module = remapped_module.removeprefix("ethereum.ercs") + remapped_module = vyper.builtins.interfaces.__package__ + remapped_module + + path = _import_to_path(level, remapped_module).with_suffix(".vyi") + + # builtins are globally the same, so we can safely cache them + # (it is also *correct* to cache them, so that types defined in builtins + # compare correctly using pointer-equality.) + if path in _builtins_cache: + file, module_t = _builtins_cache[path] + return file, module_t.interface + + try: + file = input_bundle.load_file(path) + assert isinstance(file, FileInput) # mypy hint + except FileNotFoundError as e: + hint = None + components = module_str.split(".") + # common issue for upgrading codebases from v0.3.x to v0.4.x - + # hint: rename ERC20 to IERC20 + if components[-1].startswith("ERC"): + module_prefix = components[-1] + hint = f"try renaming `{module_prefix}` to `I{module_prefix}`" + raise ModuleNotFound(module_str, hint=hint) from e + + interface_ast = _parse_ast(file) + + # no recursion needed since builtins don't have any imports + + _builtins_cache[path] = file, interface_ast + return file, interface_ast + + +def resolve_imports(module_ast: vy_ast.Module): + graph = _ImportGraph() + analyzer = ImportAnalyzer(graph) + analyzer.resolve_imports(module_ast) + + return analyzer diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index d05e494b80..62bee3f9fb 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -1,17 +1,8 @@ -import os -from pathlib import Path, PurePath +from pathlib import PurePath from typing import Any, Optional -import vyper.builtins.interfaces from vyper import ast as vy_ast -from vyper.compiler.input_bundle import ( - ABIInput, - CompilerInput, - FileInput, - FilesystemInputBundle, - InputBundle, - PathLike, -) +from vyper.compiler.input_bundle import ABIInput, CompilerInput from vyper.evm.opcodes import version_check from vyper.exceptions import ( BorrowException, @@ -25,7 +16,6 @@ InterfaceViolation, InvalidLiteral, InvalidType, - ModuleNotFound, StateAccessViolation, StructureException, UndeclaredDefinition, @@ -45,7 +35,6 @@ from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.getters import generate_public_variable_getters -from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, analyze_functions, check_module_uses from vyper.semantics.analysis.utils import ( check_modifiability, @@ -61,29 +50,16 @@ from vyper.utils import OrderedSet -def analyze_module( - module_ast: vy_ast.Module, - input_bundle: InputBundle, - import_graph: ImportGraph = None, - is_interface: bool = False, -) -> ModuleT: +def analyze_module(module_ast: vy_ast.Module) -> ModuleT: """ Analyze a Vyper module AST node, recursively analyze all its imports, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ - if import_graph is None: - import_graph = ImportGraph() - - return _analyze_module_r(module_ast, input_bundle, import_graph, is_interface) + return _analyze_module_r(module_ast) -def _analyze_module_r( - module_ast: vy_ast.Module, - input_bundle: InputBundle, - import_graph: ImportGraph, - is_interface: bool = False, -): +def _analyze_module_r(module_ast: vy_ast.Module, is_interface: bool = False): if "type" in module_ast._metadata: # we don't need to analyse again, skip out assert isinstance(module_ast._metadata["type"], ModuleT) @@ -92,8 +68,8 @@ def _analyze_module_r( # validate semantics and annotate AST with type/semantics information namespace = get_namespace() - with namespace.enter_scope(), import_graph.enter_path(module_ast): - analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph, is_interface) + with namespace.enter_scope(): + analyzer = ModuleAnalyzer(module_ast, namespace, is_interface) analyzer.analyze_module_body() _analyze_call_graph(module_ast) @@ -176,20 +152,14 @@ class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" def __init__( - self, - module_node: vy_ast.Module, - input_bundle: InputBundle, - namespace: Namespace, - import_graph: ImportGraph, - is_interface: bool = False, + self, module_node: vy_ast.Module, namespace: Namespace, is_interface: bool = False ) -> None: self.ast = module_node - self.input_bundle = input_bundle self.namespace = namespace - self._import_graph = import_graph self.is_interface = is_interface # keep track of imported modules to prevent duplicate imports + # TODO: move this to ImportAnalyzer self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {} # keep track of exported functions to prevent duplicate exports @@ -199,6 +169,9 @@ def __init__( self.module_t: Optional[ModuleT] = None + def resolve_imports(self): + pass + def analyze_module_body(self): # generate a `ModuleT` from the top-level node # note: also validates unique method ids @@ -390,16 +363,6 @@ def validate_initialized_modules(self): err_list.raise_if_not_empty() - def _ast_from_file(self, file: FileInput) -> vy_ast.Module: - # cache ast if we have seen it before. - # this gives us the additional property of object equality on - # two ASTs produced from the same source - ast_of = self.input_bundle._cache._ast_of - if file.source_id not in ast_of: - ast_of[file.source_id] = _parse_ast(file) - - return ast_of[file.source_id] - def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) @@ -740,32 +703,17 @@ def visit_FunctionDef(self, node): self._add_exposed_function(func_t, node) def visit_Import(self, node): - # import x.y[name] as y[alias] - - alias = node.alias - - if alias is None: - alias = node.name - - # don't handle things like `import x.y` - if "." in alias: - msg = "import requires an accompanying `as` statement" - suggested_alias = node.name[node.name.rfind(".") :] - hint = f"try `import {node.name} as {suggested_alias}`" - raise StructureException(msg, node, hint=hint) - - self._add_import(node, 0, node.name, alias) + self._add_import(node) def visit_ImportFrom(self, node): - # from m.n[module] import x[name] as y[alias] - alias = node.alias or node.name - - module = node.module or "" - if module: - module += "." + self._add_import(node) - qualified_module_name = module + node.name - self._add_import(node, node.level, qualified_module_name, alias) + def _add_import(self, node: vy_ast.VyperNode, alias: str) -> None: + compiler_input, module_info = self._load_import(node, alias) + node._metadata["import_info"] = ImportInfo( + module_info, alias, qualified_module_name, compiler_input, node + ) + self.namespace[alias] = module_info def visit_InterfaceDef(self, node): interface_t = InterfaceT.from_InterfaceDef(node) @@ -777,31 +725,10 @@ def visit_StructDef(self, node): node._metadata["struct_type"] = struct_t self.namespace[node.name] = struct_t - def _add_import( - self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str - ) -> None: - compiler_input, module_info = self._load_import(node, level, qualified_module_name, alias) - node._metadata["import_info"] = ImportInfo( - module_info, alias, qualified_module_name, compiler_input, node - ) - self.namespace[alias] = module_info - - # load an InterfaceT or ModuleInfo from an import. - # raises FileNotFoundError - def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: - # the directory this (currently being analyzed) module is in - self_search_path = Path(self.ast.resolved_path).parent - - with self.input_bundle.poke_search_path(self_search_path): - return self._load_import_helper(node, level, module_str, alias) - def _load_import_helper( self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str ) -> tuple[CompilerInput, Any]: - if _is_builtin(module_str): - return _load_builtin_import(level, module_str) - - path = _import_to_path(level, module_str) + path = node._metadata["path"] # this could conceivably be in the ImportGraph but no need at this point if path in self._imported_modules: @@ -810,156 +737,28 @@ def _load_import_helper( self._imported_modules[path] = node - err = None - - try: - path_vy = path.with_suffix(".vy") - file = self.input_bundle.load_file(path_vy) - assert isinstance(file, FileInput) # mypy hint - - module_ast = self._ast_from_file(file) - + if path.suffix == "vy": + file = node._metadata["compiler_input"] + module_ast = node._metadata["ast"] with override_global_namespace(Namespace()): - module_t = _analyze_module_r( - module_ast, - self.input_bundle, - import_graph=self._import_graph, - is_interface=False, - ) + module_t = _analyze_module_r(module_ast, is_interface=False) return file, ModuleInfo(module_t, alias) - except FileNotFoundError as e: - # escape `e` from the block scope, it can make things - # easier to debug. - err = e - - try: - file = self.input_bundle.load_file(path.with_suffix(".vyi")) - assert isinstance(file, FileInput) # mypy hint - module_ast = self._ast_from_file(file) - + if path.suffix == "vyi": + file = node._metadata["compiler_input"] + module_ast = node._metadata["ast"] with override_global_namespace(Namespace()): - _analyze_module_r( - module_ast, - self.input_bundle, - import_graph=self._import_graph, - is_interface=True, - ) + _analyze_module_r(module_ast, is_interface=True) module_t = module_ast._metadata["type"] + # TODO: return the whole module return file, module_t.interface - except FileNotFoundError: - pass - - try: - file = self.input_bundle.load_file(path.with_suffix(".json")) + if path.suffix == "json": + file = node._metadata["compiler_input"] + module_ast = node._metadata["ast"] assert isinstance(file, ABIInput) # mypy hint return file, InterfaceT.from_json_abi(str(file.path), file.abi) - except FileNotFoundError: - pass - - hint = None - if module_str.startswith("vyper.interfaces"): - hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" - - # copy search_paths, makes debugging a bit easier - search_paths = self.input_bundle.search_paths.copy() # noqa: F841 - raise ModuleNotFound(module_str, hint=hint) from err - - -def _parse_ast(file: FileInput) -> vy_ast.Module: - module_path = file.resolved_path # for error messages - try: - # try to get a relative path, to simplify the error message - cwd = Path(".") - if module_path.is_absolute(): - cwd = cwd.resolve() - module_path = module_path.relative_to(cwd) - except ValueError: - # we couldn't get a relative path (cf. docs for Path.relative_to), - # use the resolved path given to us by the InputBundle - pass - - ret = vy_ast.parse_to_ast( - file.source_code, - source_id=file.source_id, - module_path=module_path.as_posix(), - resolved_path=file.resolved_path.as_posix(), - ) - return ret - - -# convert an import to a path (without suffix) -def _import_to_path(level: int, module_str: str) -> PurePath: - base_path = "" - if level > 1: - base_path = "../" * (level - 1) - elif level == 1: - base_path = "./" - return PurePath(f"{base_path}{module_str.replace('.','/')}/") - - -# can add more, e.g. "vyper.builtins.interfaces", etc. -BUILTIN_PREFIXES = ["ethereum.ercs"] - - -# TODO: could move this to analysis/common.py or something -def _is_builtin(module_str): - return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES) - - -_builtins_cache: dict[PathLike, tuple[CompilerInput, ModuleT]] = {} - - -def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, InterfaceT]: - if not _is_builtin(module_str): # pragma: nocover - raise CompilerPanic("unreachable!") - - builtins_path = vyper.builtins.interfaces.__path__[0] - # hygiene: convert to relpath to avoid leaking user directory info - # (note Path.relative_to cannot handle absolute to relative path - # conversion, so we must use the `os` module). - builtins_path = os.path.relpath(builtins_path) - - search_path = Path(builtins_path).parent.parent.parent - # generate an input bundle just because it knows how to build paths. - input_bundle = FilesystemInputBundle([search_path]) - - # remap builtins directory -- - # ethereum/ercs => vyper/builtins/interfaces - remapped_module = module_str - if remapped_module.startswith("ethereum.ercs"): - remapped_module = remapped_module.removeprefix("ethereum.ercs") - remapped_module = vyper.builtins.interfaces.__package__ + remapped_module - - path = _import_to_path(level, remapped_module).with_suffix(".vyi") - - # builtins are globally the same, so we can safely cache them - # (it is also *correct* to cache them, so that types defined in builtins - # compare correctly using pointer-equality.) - if path in _builtins_cache: - file, module_t = _builtins_cache[path] - return file, module_t.interface - - try: - file = input_bundle.load_file(path) - assert isinstance(file, FileInput) # mypy hint - except FileNotFoundError as e: - hint = None - components = module_str.split(".") - # common issue for upgrading codebases from v0.3.x to v0.4.x - - # hint: rename ERC20 to IERC20 - if components[-1].startswith("ERC"): - module_prefix = components[-1] - hint = f"try renaming `{module_prefix}` to `I{module_prefix}`" - raise ModuleNotFound(module_str, hint=hint) from e - - interface_ast = _parse_ast(file) - - with override_global_namespace(Namespace()): - module_t = _analyze_module_r(interface_ast, input_bundle, ImportGraph(), is_interface=True) - _builtins_cache[path] = file, module_t - return file, module_t.interface + raise CompilerPanic("unreachable") # pragma: nocover From dc5c0e8dff83386634a6bb1dd324fbd256cafabe Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 8 Sep 2024 08:21:08 -0400 Subject: [PATCH 02/14] modify ImportInfo --- vyper/semantics/analysis/base.py | 7 ++++--- vyper/semantics/analysis/imports.py | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 65bc8df3ab..93c159f322 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -121,11 +121,13 @@ def __hash__(self): @dataclass(frozen=True) class ImportInfo(AnalysisResult): - typ: Union[ModuleInfo, "InterfaceT"] alias: str # the name in the namespace qualified_module_name: str # for error messages compiler_input: CompilerInput # to recover file info for ast export - node: vy_ast.VyperNode + parsed: Any # (json) abi | AST + + # TODO: is this field used? + node: vy_ast._ImportStmt # the importing node def to_dict(self): ret = {"alias": self.alias, "qualified_module_name": self.qualified_module_name} @@ -137,7 +139,6 @@ def to_dict(self): return ret - # analysis result of InitializesDecl @dataclass class InitializesInfo(AnalysisResult): diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index 9fdddf8266..f8f8b3222e 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -117,16 +117,6 @@ def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alia with self.input_bundle.poke_search_path(self_search_path): return self._load_import_helper(node, level, module_str, alias) - def _ast_from_file(self, file: FileInput) -> vy_ast.Module: - # cache ast if we have seen it before. - # this gives us the additional property of object equality on - # two ASTs produced from the same source - ast_of = self._ast_of - if file.source_id not in ast_of: - ast_of[file.source_id] = _parse_ast(file) - - return ast_of[file.source_id] - def _load_import_helper( self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str ) -> tuple[CompilerInput, Any]: @@ -186,6 +176,16 @@ def _load_import_helper( search_paths = self.input_bundle.search_paths.copy() # noqa: F841 raise ModuleNotFound(module_str, hint=hint) from err + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: + # cache ast if we have seen it before. + # this gives us the additional property of object equality on + # two ASTs produced from the same source + ast_of = self._ast_of + if file.source_id not in ast_of: + ast_of[file.source_id] = _parse_ast(file) + + return ast_of[file.source_id] + def _parse_ast(file: FileInput) -> vy_ast.Module: module_path = file.resolved_path # for error messages From 387f50d8fe4dfdc435745046bdb8f5f23b6bbc1f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 8 Sep 2024 09:22:53 -0400 Subject: [PATCH 03/14] wip --- vyper/compiler/output_bundle.py | 2 +- vyper/compiler/phases.py | 7 ++- vyper/semantics/analysis/base.py | 15 ++++-- vyper/semantics/analysis/imports.py | 52 ++++++++++++++------ vyper/semantics/analysis/module.py | 73 ++++++++++++----------------- 5 files changed, 83 insertions(+), 66 deletions(-) diff --git a/vyper/compiler/output_bundle.py b/vyper/compiler/output_bundle.py index 92494e3a70..cfc8c18460 100644 --- a/vyper/compiler/output_bundle.py +++ b/vyper/compiler/output_bundle.py @@ -12,7 +12,7 @@ from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings from vyper.exceptions import CompilerPanic -from vyper.semantics.analysis.module import _is_builtin +from vyper.semantics.analysis.imports import _is_builtin from vyper.utils import get_long_version # data structures and routines for constructing "output bundles", diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 147af24d67..3ae9de8741 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -13,6 +13,7 @@ from vyper.ir import compile_ir, optimizer from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target from vyper.semantics.analysis.data_positions import generate_layout_export +from vyper.semantics.analysis.imports import resolve_imports from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout @@ -283,8 +284,10 @@ def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundl """ vyper_module = copy.deepcopy(vyper_module) with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - # note: analyze_module does type inference on the AST - analyze_module(vyper_module, input_bundle) + # TODO: move this to its own pass + resolve_imports(vyper_module, input_bundle) + # note: analyze_module does type inference on the AST + analyze_module(vyper_module) return vyper_module diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 93c159f322..4030b6a47d 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,7 +1,7 @@ import enum from dataclasses import dataclass from functools import cached_property -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional from vyper import ast as vy_ast from vyper.compiler.input_bundle import CompilerInput @@ -13,7 +13,7 @@ if TYPE_CHECKING: from vyper.semantics.types.function import ContractFunctionT - from vyper.semantics.types.module import InterfaceT, ModuleT + from vyper.semantics.types.module import ModuleT class FunctionVisibility(StringEnum): @@ -119,15 +119,19 @@ def __hash__(self): return hash(id(self.module_t)) -@dataclass(frozen=True) +@dataclass class ImportInfo(AnalysisResult): alias: str # the name in the namespace qualified_module_name: str # for error messages compiler_input: CompilerInput # to recover file info for ast export parsed: Any # (json) abi | AST + _typ: Any = None # type to be filled in during analysis - # TODO: is this field used? - node: vy_ast._ImportStmt # the importing node + @property + def typ(self): + if self._typ is None: # pragma: nocover + raise CompilerPanic("unreachable!") + return self._typ def to_dict(self): ret = {"alias": self.alias, "qualified_module_name": self.qualified_module_name} @@ -139,6 +143,7 @@ def to_dict(self): return ret + # analysis result of InitializesDecl @dataclass class InitializesInfo(AnalysisResult): diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index f8f8b3222e..070c20f649 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -4,7 +4,7 @@ from pathlib import Path, PurePath from typing import Any, Iterator -import vyper +import vyper.builtins.interfaces from vyper import ast as vy_ast from vyper.compiler.input_bundle import ( ABIInput, @@ -21,9 +21,13 @@ ModuleNotFound, StructureException, ) +from vyper.semantics.analysis.base import ImportInfo """ -collect import statements and validate the import graph +collect import statements and validate the import graph. +this module is separated into its own pass so that we can resolve the import +graph quickly (without doing semantic analysis) and for cleanliness, to +segregate the I/O portion of semantic analysis into its own pass. """ @@ -32,17 +36,31 @@ class _ImportGraph: # the current path in the import graph traversal _path: list[vy_ast.Module] = field(default_factory=list) + # stack of dicts, each item in the stack is a dict keeping + # track of imports in the current module + _imports: list[dict] = field(default_factory=list) + + @property + def imported_modules(self): + return self._imports[-1] + + @property + def current_module(self): + return self._path[-1] + def push_path(self, module_ast: vy_ast.Module) -> None: if module_ast in self._path: cycle = self._path + [module_ast] raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) self._path.append(module_ast) + self._imports.append({}) def pop_path(self, expected: vy_ast.Module) -> None: popped = self._path.pop() if expected != popped: raise CompilerPanic("unreachable") + self._imports.pop() @contextlib.contextmanager def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: @@ -55,8 +73,9 @@ def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: class ImportAnalyzer: def __init__(self, input_bundle: InputBundle, graph: _ImportGraph): + self.input_bundle = input_bundle self.graph = graph - self._ast_of: dict[PathLike, vy_ast.Module] = {} + self._ast_of: dict[int, vy_ast.Module] = {} self.integrity_sum = None @@ -64,6 +83,10 @@ def resolve_imports(self, module_ast: vy_ast.Module): self._resolve_imports_r(module_ast) self.integrity_sum = self._calculate_integrity_sum(module_ast) + def _calculate_integrity_sum(self, module_ast: vy_ast.Module): + # TODO: stub + pass + def _resolve_imports_r(self, module_ast: vy_ast.Module): with self.graph.enter_path(module_ast): for node in module_ast.body: @@ -104,15 +127,16 @@ def _add_import( self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str ) -> None: compiler_input, ast = self._load_import(node, level, qualified_module_name, alias) - node._metadata["compiler_input"] = compiler_input - node._metadata["imported_ast"] = ast - node._metadata["alias"] = alias + node._metadata["import_info"] = ImportInfo( + alias, qualified_module_name, compiler_input, ast + ) # load an InterfaceT or ModuleInfo from an import. # raises FileNotFoundError def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: # the directory this (currently being analyzed) module is in - self_search_path = Path(self.ast.resolved_path).parent + ast = self.graph.current_module + self_search_path = Path(ast.resolved_path).parent with self.input_bundle.poke_search_path(self_search_path): return self._load_import_helper(node, level, module_str, alias) @@ -125,11 +149,11 @@ def _load_import_helper( path = _import_to_path(level, module_str) - if path in self.graph._imported_modules: - previous_import_stmt = self._imported_modules[path] + if path in self.graph.imported_modules: + previous_import_stmt = self.graph.imported_modules[path] raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) - self._imported_modules[path] = node + self.graph.imported_modules[path] = node err = None @@ -258,8 +282,8 @@ def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, vy # (it is also *correct* to cache them, so that types defined in builtins # compare correctly using pointer-equality.) if path in _builtins_cache: - file, module_t = _builtins_cache[path] - return file, module_t.interface + file, ast = _builtins_cache[path] + return file, ast try: file = input_bundle.load_file(path) @@ -282,9 +306,9 @@ def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, vy return file, interface_ast -def resolve_imports(module_ast: vy_ast.Module): +def resolve_imports(module_ast: vy_ast.Module, input_bundle: InputBundle): graph = _ImportGraph() - analyzer = ImportAnalyzer(graph) + analyzer = ImportAnalyzer(input_bundle, graph) analyzer.resolve_imports(module_ast) return analyzer diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 62bee3f9fb..fabecd55d2 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -2,13 +2,11 @@ from typing import Any, Optional from vyper import ast as vy_ast -from vyper.compiler.input_bundle import ABIInput, CompilerInput from vyper.evm.opcodes import version_check from vyper.exceptions import ( BorrowException, CallViolation, CompilerPanic, - DuplicateImport, EvmVersionException, ExceptionList, ImmutableViolation, @@ -708,57 +706,44 @@ def visit_Import(self, node): def visit_ImportFrom(self, node): self._add_import(node) - def _add_import(self, node: vy_ast.VyperNode, alias: str) -> None: - compiler_input, module_info = self._load_import(node, alias) - node._metadata["import_info"] = ImportInfo( - module_info, alias, qualified_module_name, compiler_input, node - ) - self.namespace[alias] = module_info - - def visit_InterfaceDef(self, node): - interface_t = InterfaceT.from_InterfaceDef(node) - node._metadata["interface_type"] = interface_t - self.namespace[node.name] = interface_t - - def visit_StructDef(self, node): - struct_t = StructT.from_StructDef(node) - node._metadata["struct_type"] = struct_t - self.namespace[node.name] = struct_t - - def _load_import_helper( - self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str - ) -> tuple[CompilerInput, Any]: - path = node._metadata["path"] + def _add_import(self, node: vy_ast.VyperNode) -> None: + import_info = node._metadata["import_info"] + # similar structure to import analyzer + module_info = self._load_import(node, import_info) - # this could conceivably be in the ImportGraph but no need at this point - if path in self._imported_modules: - previous_import_stmt = self._imported_modules[path] - raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) + import_info._typ = module_info - self._imported_modules[path] = node + self.namespace[import_info.alias] = module_info - if path.suffix == "vy": - file = node._metadata["compiler_input"] - module_ast = node._metadata["ast"] + def _load_import(self, node: vy_ast.VyperNode, import_info: ImportInfo) -> Any: + path = import_info.compiler_input.path + if path.suffix == ".vy": + module_ast = import_info.parsed with override_global_namespace(Namespace()): module_t = _analyze_module_r(module_ast, is_interface=False) + return ModuleInfo(module_t, import_info.alias) - return file, ModuleInfo(module_t, alias) - - if path.suffix == "vyi": - file = node._metadata["compiler_input"] - module_ast = node._metadata["ast"] + if path.suffix == ".vyi": + module_ast = import_info.parsed with override_global_namespace(Namespace()): - _analyze_module_r(module_ast, is_interface=True) - module_t = module_ast._metadata["type"] + module_t = _analyze_module_r(module_ast, is_interface=True) # TODO: return the whole module - return file, module_t.interface + return module_t.interface - if path.suffix == "json": - file = node._metadata["compiler_input"] - module_ast = node._metadata["ast"] - assert isinstance(file, ABIInput) # mypy hint - return file, InterfaceT.from_json_abi(str(file.path), file.abi) + if path.suffix == ".json": + abi = import_info.parsed + path = import_info.compiler_input.path + return InterfaceT.from_json_abi(str(path), abi) raise CompilerPanic("unreachable") # pragma: nocover + + def visit_InterfaceDef(self, node): + interface_t = InterfaceT.from_InterfaceDef(node) + node._metadata["interface_type"] = interface_t + self.namespace[node.name] = interface_t + + def visit_StructDef(self, node): + struct_t = StructT.from_StructDef(node) + node._metadata["struct_type"] = struct_t + self.namespace[node.name] = struct_t From 9550097adc259966e93817c14bb62df4f2540813 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 8 Sep 2024 10:33:39 -0400 Subject: [PATCH 04/14] make integrity sum fast --- vyper/compiler/output.py | 2 +- vyper/compiler/phases.py | 18 ++++++++++++------ vyper/semantics/analysis/imports.py | 18 ++++++++++++++---- vyper/semantics/types/module.py | 15 --------------- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 577afd3822..f9aa4bd23d 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -76,7 +76,7 @@ def build_archive_b64(compiler_data: CompilerData) -> str: def build_integrity(compiler_data: CompilerData) -> str: - return compiler_data.compilation_target._metadata["type"].integrity_sum + return compiler_data.resolved_imports.integrity_sum def build_external_interface_output(compiler_data: CompilerData) -> str: diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 3ae9de8741..7b9b18e909 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -146,9 +146,19 @@ def vyper_module(self): _, ast = self._generate_ast return ast + @cached_property + def _resolve_imports(self): + vyper_module = copy.deepcopy(self.vyper_module) + with self.input_bundle.search_path(Path(vyper_module.resolved_path).parent): + return vyper_module, resolve_imports(vyper_module, self.input_bundle) + + @cached_property + def resolved_imports(self): + return self._resolve_imports[1] + @cached_property def _annotate(self) -> tuple[natspec.NatspecOutput, vy_ast.Module]: - module = generate_annotated_ast(self.vyper_module, self.input_bundle) + module = generate_annotated_ast(self._resolve_imports[0]) nspec = natspec.parse_natspec(module) return nspec, module @@ -268,7 +278,7 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: +def generate_annotated_ast(vyper_module: vy_ast.Module) -> vy_ast.Module: """ Validates and annotates the Vyper AST. @@ -282,10 +292,6 @@ def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundl vy_ast.Module Annotated Vyper AST """ - vyper_module = copy.deepcopy(vyper_module) - with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - # TODO: move this to its own pass - resolve_imports(vyper_module, input_bundle) # note: analyze_module does type inference on the AST analyze_module(vyper_module) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index 070c20f649..5472074f4a 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -1,5 +1,6 @@ import contextlib import os +from vyper.utils import sha256sum from dataclasses import dataclass, field from pathlib import Path, PurePath from typing import Any, Iterator @@ -81,11 +82,20 @@ def __init__(self, input_bundle: InputBundle, graph: _ImportGraph): def resolve_imports(self, module_ast: vy_ast.Module): self._resolve_imports_r(module_ast) - self.integrity_sum = self._calculate_integrity_sum(module_ast) + self.integrity_sum = self._calculate_integrity_sum_r(module_ast) - def _calculate_integrity_sum(self, module_ast: vy_ast.Module): - # TODO: stub - pass + def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module): + acc = [sha256sum(module_ast.full_source_code)] + for s in module_ast.get_children(vy_ast._ImportStmt): + info = s._metadata["import_info"] + + if info.compiler_input.path.suffix in (".vyi", ".json"): + # NOTE: this needs to be redone if interfaces can import other interfaces + acc.append(info.compiler_input.sha256sum) + else: + acc.append(self._calculate_integrity_sum_r(info.parsed)) + + return sha256sum("".join(acc)) def _resolve_imports_r(self, module_ast: vy_ast.Module): with self.graph.enter_path(module_ast): diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index ba72842c65..2b38cdf3bd 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -437,21 +437,6 @@ def reachable_imports(self) -> list["ImportInfo"]: return ret - @cached_property - def integrity_sum(self) -> str: - acc = [sha256sum(self._module.full_source_code)] - for s in self.import_stmts: - info = s._metadata["import_info"] - - if isinstance(info.typ, InterfaceT): - # NOTE: this needs to be redone if interfaces can import other interfaces - acc.append(info.compiler_input.sha256sum) - else: - assert isinstance(info.typ.typ, ModuleT) - acc.append(info.typ.typ.integrity_sum) - - return sha256sum("".join(acc)) - def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]: for s in self.imported_modules.values(): if s.module_t == needle: From f8f7529c38b426cc98984399be2cbf2dc3395338 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 8 Sep 2024 11:01:15 -0400 Subject: [PATCH 05/14] wip --- vyper/compiler/output_bundle.py | 2 +- vyper/compiler/phases.py | 23 ++--------------------- vyper/semantics/analysis/module.py | 7 ------- 3 files changed, 3 insertions(+), 29 deletions(-) diff --git a/vyper/compiler/output_bundle.py b/vyper/compiler/output_bundle.py index cfc8c18460..7c989b19ea 100644 --- a/vyper/compiler/output_bundle.py +++ b/vyper/compiler/output_bundle.py @@ -159,7 +159,7 @@ def write(self): self.write_compilation_target([self.bundle.compilation_target_path]) self.write_search_paths(self.bundle.used_search_paths) self.write_settings(self.compiler_data.original_settings) - self.write_integrity(self.bundle.compilation_target.integrity_sum) + self.write_integrity(self.compiler_data.resolved_imports.integrity_sum) self.write_sources(self.bundle.compiler_inputs) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 7b9b18e909..0017f805f7 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -158,7 +158,8 @@ def resolved_imports(self): @cached_property def _annotate(self) -> tuple[natspec.NatspecOutput, vy_ast.Module]: - module = generate_annotated_ast(self._resolve_imports[0]) + module = self._resolve_imports[0] + analyze_module(module) nspec = natspec.parse_natspec(module) return nspec, module @@ -278,26 +279,6 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_annotated_ast(vyper_module: vy_ast.Module) -> vy_ast.Module: - """ - Validates and annotates the Vyper AST. - - Arguments - --------- - vyper_module : vy_ast.Module - Top-level Vyper AST node - - Returns - ------- - vy_ast.Module - Annotated Vyper AST - """ - # note: analyze_module does type inference on the AST - analyze_module(vyper_module) - - return vyper_module - - def generate_ir_nodes(global_ctx: ModuleT, settings: Settings) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index fabecd55d2..16841e6948 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -156,10 +156,6 @@ def __init__( self.namespace = namespace self.is_interface = is_interface - # keep track of imported modules to prevent duplicate imports - # TODO: move this to ImportAnalyzer - self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {} - # keep track of exported functions to prevent duplicate exports self._all_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {} @@ -167,9 +163,6 @@ def __init__( self.module_t: Optional[ModuleT] = None - def resolve_imports(self): - pass - def analyze_module_body(self): # generate a `ModuleT` from the top-level node # note: also validates unique method ids From c4124022417a1e7f913f5958b7b1fd4effca4631 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 8 Sep 2024 11:40:39 -0400 Subject: [PATCH 06/14] fix lint --- vyper/semantics/analysis/imports.py | 4 ++-- vyper/semantics/analysis/module.py | 1 - vyper/semantics/types/module.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index 5472074f4a..01ce3ef0b8 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -1,6 +1,5 @@ import contextlib import os -from vyper.utils import sha256sum from dataclasses import dataclass, field from pathlib import Path, PurePath from typing import Any, Iterator @@ -23,6 +22,7 @@ StructureException, ) from vyper.semantics.analysis.base import ImportInfo +from vyper.utils import sha256sum """ collect import statements and validate the import graph. @@ -86,7 +86,7 @@ def resolve_imports(self, module_ast: vy_ast.Module): def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module): acc = [sha256sum(module_ast.full_source_code)] - for s in module_ast.get_children(vy_ast._ImportStmt): + for s in module_ast.get_children((vy_ast.Import, vy_ast.ImportFrom)): info = s._metadata["import_info"] if info.compiler_input.path.suffix in (".vyi", ".json"): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 16841e6948..b7945ab8cf 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -1,4 +1,3 @@ -from pathlib import PurePath from typing import Any, Optional from vyper import ast as vy_ast diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 2b38cdf3bd..d6cc50a2ea 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -22,7 +22,7 @@ from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, StructT, _UserType -from vyper.utils import OrderedSet, sha256sum +from vyper.utils import OrderedSet if TYPE_CHECKING: from vyper.semantics.analysis.base import ImportInfo, ModuleInfo From 0716f0aafb6c146e6e4e8fb72fc5b249b4c6e6a2 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 8 Sep 2024 11:47:24 -0400 Subject: [PATCH 07/14] keep track of visited --- vyper/semantics/analysis/imports.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index 01ce3ef0b8..d04dd61ac9 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -78,6 +78,8 @@ def __init__(self, input_bundle: InputBundle, graph: _ImportGraph): self.graph = graph self._ast_of: dict[int, vy_ast.Module] = {} + self.seen = set() + self.integrity_sum = None def resolve_imports(self, module_ast: vy_ast.Module): @@ -98,12 +100,15 @@ def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module): return sha256sum("".join(acc)) def _resolve_imports_r(self, module_ast: vy_ast.Module): + if id(module_ast) in self.seen: + return with self.graph.enter_path(module_ast): for node in module_ast.body: if isinstance(node, vy_ast.Import): self._handle_Import(node) elif isinstance(node, vy_ast.ImportFrom): self._handle_ImportFrom(node) + self.seen.add(id(module_ast)) def _handle_Import(self, node: vy_ast.Import): # import x.y[name] as y[alias] From 1ba7d67b56c0e632329c572db509e3e2be2d61ac Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 5 Oct 2024 14:35:42 -0400 Subject: [PATCH 08/14] fix lint --- vyper/semantics/analysis/imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index 8fc3c8741b..6c06b767f4 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -77,7 +77,7 @@ def __init__(self, input_bundle: InputBundle, graph: _ImportGraph): self.graph = graph self._ast_of: dict[int, vy_ast.Module] = {} - self.seen = set() + self.seen: set[int] = set() self.integrity_sum = None From dadaa7a087f5e0360fdf1d4c82ae84e4a70d8399 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 5 Oct 2024 14:40:51 -0400 Subject: [PATCH 09/14] fix some tests --- tests/conftest.py | 8 +-- .../codegen/types/numbers/test_decimals.py | 2 +- tests/unit/ast/nodes/test_hex.py | 4 +- .../semantics/analysis/test_array_index.py | 20 +++---- .../analysis/test_cyclic_function_calls.py | 24 ++++----- .../unit/semantics/analysis/test_for_loop.py | 52 +++++++++---------- vyper/compiler/phases.py | 29 ++++++----- 7 files changed, 67 insertions(+), 72 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 31c72246bd..76ebc2df22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from tests.utils import working_directory from vyper import compiler from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle +from vyper.compiler.input_bundle import FilesystemInputBundle from vyper.compiler.settings import OptimizationLevel, Settings, set_global_settings from vyper.exceptions import EvmVersionException from vyper.ir import compile_ir, optimizer @@ -166,12 +166,6 @@ def fn(sources_dict): return fn -# for tests which just need an input bundle, doesn't matter what it is -@pytest.fixture -def dummy_input_bundle(): - return InputBundle([]) - - @pytest.fixture(scope="module") def gas_limit(): # set absurdly high gas limit so that london basefee never adjusts diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 36c14f804d..ad8bf74b0d 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -299,7 +299,7 @@ def foo(): compile_code(code) -def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): +def test_replace_decimal_nested_intermediate_underflow(): code = """ @external def foo(): diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index 7168defa99..6d82b1d2ab 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -40,7 +40,7 @@ def foo(): @pytest.mark.parametrize("code", code_invalid_checksum) -def test_invalid_checksum(code, dummy_input_bundle): +def test_invalid_checksum(code): with pytest.raises(InvalidLiteral): vyper_module = vy_ast.parse_to_ast(code) - semantics.analyze_module(vyper_module, dummy_input_bundle) + semantics.analyze_module(vyper_module) diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index b5bf86494d..aa9a702be3 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"]) -def test_type_mismatch(namespace, value, dummy_input_bundle): +def test_type_mismatch(namespace, value): code = f""" a: uint256[3] @@ -22,11 +22,11 @@ def foo(b: {value}): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) -def test_invalid_literal(namespace, value, dummy_input_bundle): +def test_invalid_literal(namespace, value): code = f""" a: uint256[3] @@ -37,11 +37,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1]) -def test_out_of_bounds(namespace, value, dummy_input_bundle): +def test_out_of_bounds(namespace, value): code = f""" a: uint256[3] @@ -52,11 +52,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ArrayIndexException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["b", "self.b"]) -def test_undeclared_definition(namespace, value, dummy_input_bundle): +def test_undeclared_definition(namespace, value): code = f""" a: uint256[3] @@ -67,11 +67,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(UndeclaredDefinition): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["a", "foo", "int128"]) -def test_invalid_reference(namespace, value, dummy_input_bundle): +def test_invalid_reference(namespace, value): code = f""" a: uint256[3] @@ -82,4 +82,4 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidReference): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) diff --git a/tests/unit/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py index 406adc00ab..da2e63c5fc 100644 --- a/tests/unit/semantics/analysis/test_cyclic_function_calls.py +++ b/tests/unit/semantics/analysis/test_cyclic_function_calls.py @@ -5,7 +5,7 @@ from vyper.semantics.analysis import analyze_module -def test_self_function_call(dummy_input_bundle): +def test_self_function_call(): code = """ @internal def foo(): @@ -13,12 +13,12 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value.message == "Contract contains cyclic function call: foo -> foo" -def test_self_function_call2(dummy_input_bundle): +def test_self_function_call2(): code = """ @external def foo(): @@ -30,12 +30,12 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value.message == "Contract contains cyclic function call: foo -> bar -> bar" -def test_cyclic_function_call(dummy_input_bundle): +def test_cyclic_function_call(): code = """ @internal def foo(): @@ -47,12 +47,12 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value.message == "Contract contains cyclic function call: foo -> bar -> foo" -def test_multi_cyclic_function_call(dummy_input_bundle): +def test_multi_cyclic_function_call(): code = """ @internal def foo(): @@ -72,14 +72,14 @@ def potato(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) expected_message = "Contract contains cyclic function call: foo -> bar -> baz -> potato -> foo" assert e.value.message == expected_message -def test_multi_cyclic_function_call2(dummy_input_bundle): +def test_multi_cyclic_function_call2(): code = """ @internal def foo(): @@ -99,14 +99,14 @@ def potato(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) expected_message = "Contract contains cyclic function call: foo -> bar -> baz -> potato -> bar" assert e.value.message == expected_message -def test_global_ann_assign_callable_no_crash(dummy_input_bundle): +def test_global_ann_assign_callable_no_crash(): code = """ balanceOf: public(HashMap[address, uint256]) @@ -116,5 +116,5 @@ def foo(to : address): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException) as excinfo: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert excinfo.value.message == "HashMap[address, uint256] is not callable" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index d7d4f7083b..810ff0a8b9 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -5,7 +5,7 @@ from vyper.semantics.analysis import analyze_module -def test_modify_iterator_function_outside_loop(dummy_input_bundle): +def test_modify_iterator_function_outside_loop(): code = """ a: uint256[3] @@ -21,10 +21,10 @@ def bar(): pass """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_pass_memory_var_to_other_function(dummy_input_bundle): +def test_pass_memory_var_to_other_function(): code = """ @internal @@ -41,10 +41,10 @@ def bar(): self.foo(a) """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator(dummy_input_bundle): +def test_modify_iterator(): code = """ a: uint256[3] @@ -56,10 +56,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_bad_keywords(dummy_input_bundle): +def test_bad_keywords(): code = """ @internal @@ -70,10 +70,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(ArgumentException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_bad_bound(dummy_input_bundle): +def test_bad_bound(): code = """ @internal @@ -84,10 +84,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_function_call(dummy_input_bundle): +def test_modify_iterator_function_call(): code = """ a: uint256[3] @@ -103,10 +103,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_recursive_function_call(dummy_input_bundle): +def test_modify_iterator_recursive_function_call(): code = """ a: uint256[3] @@ -126,10 +126,10 @@ def baz(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): +def test_modify_iterator_recursive_function_call_topsort(): # test the analysis works no matter the order of functions code = """ a: uint256[3] @@ -149,12 +149,12 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `a`" -def test_modify_iterator_through_struct(dummy_input_bundle): +def test_modify_iterator_through_struct(): # GH issue 3429 code = """ struct A: @@ -170,12 +170,12 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `a`" -def test_modify_iterator_complex_expr(dummy_input_bundle): +def test_modify_iterator_complex_expr(): # GH issue 3429 # avoid false positive! code = """ @@ -189,10 +189,10 @@ def foo(): self.b[self.a[1]] = i """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_siblings(dummy_input_bundle): +def test_modify_iterator_siblings(): # test we can modify siblings in an access tree code = """ struct Foo: @@ -207,10 +207,10 @@ def foo(): self.f.b += i """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_subscript_barrier(dummy_input_bundle): +def test_modify_subscript_barrier(): # test that Subscript nodes are a barrier for analysis code = """ struct Foo: @@ -229,7 +229,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `b`" @@ -269,7 +269,7 @@ def foo(): @pytest.mark.parametrize("code", iterator_inference_codes) -def test_iterator_type_inference_checker(code, dummy_input_bundle): +def test_iterator_type_inference_checker(code): vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index db06efbb83..4f115bf4d0 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -154,7 +154,20 @@ def _resolve_imports(self): @cached_property def resolved_imports(self): - return self._resolve_imports[1] + imports = self._resolve_imports[1] + + expected = self.expected_integrity_sum + + if expected is not None and imports.integrity_sum != expected: + # warn for now. strict/relaxed mode was considered but it costs + # interface and testing complexity to add another feature flag. + vyper_warn( + f"Mismatched integrity sum! Expected {expected}" + f" but got {imports.integrity_sum}." + " (This likely indicates a corrupted archive)" + ) + + return imports @cached_property def _annotate(self) -> tuple[natspec.NatspecOutput, vy_ast.Module]: @@ -179,17 +192,6 @@ def compilation_target(self): """ module_t = self.annotated_vyper_module._metadata["type"] - expected = self.expected_integrity_sum - - if expected is not None and module_t.integrity_sum != expected: - # warn for now. strict/relaxed mode was considered but it costs - # interface and testing complexity to add another feature flag. - vyper_warn( - f"Mismatched integrity sum! Expected {expected}" - f" but got {module_t.integrity_sum}." - " (This likely indicates a corrupted archive)" - ) - validate_compilation_target(module_t) return self.annotated_vyper_module @@ -263,8 +265,7 @@ def assembly_runtime(self) -> list: def bytecode(self) -> bytes: metadata = None if not self.no_bytecode_metadata: - module_t = self.compilation_target._metadata["type"] - metadata = bytes.fromhex(module_t.integrity_sum) + metadata = bytes.fromhex(self.resolved_imports.integrity_sum) return generate_bytecode(self.assembly, compiler_metadata=metadata) @cached_property From 1206a8defe992df4484c27802e201f264dc182c0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 7 Oct 2024 10:46:15 -0400 Subject: [PATCH 10/14] pretty up code for alias --- vyper/semantics/analysis/imports.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index 6c06b767f4..acdd703f74 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -128,7 +128,11 @@ def _handle_Import(self, node: vy_ast.Import): def _handle_ImportFrom(self, node: vy_ast.ImportFrom): # from m.n[module] import x[name] as y[alias] - alias = node.alias or node.name + + alias = node.alias + + if alias is None: + alias = node.name module = node.module or "" if module: From 159efa1ce9f0ab0f62e14ea3755e84bc40f976bf Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 7 Oct 2024 10:50:50 -0400 Subject: [PATCH 11/14] update a comment --- vyper/semantics/analysis/imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index acdd703f74..be1f2da312 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -195,7 +195,7 @@ def _load_import_helper( assert isinstance(file, FileInput) # mypy hint module_ast = self._ast_from_file(file) - # no recursion yet + # language does not yet allow recursion for vyi files # self.resolve_imports(module_ast) return file, module_ast From 0e6290db6ad2132e736cf73c73e30ac5c66ed300 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 7 Oct 2024 10:52:35 -0400 Subject: [PATCH 12/14] add another comment --- vyper/semantics/analysis/module.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index b7945ab8cf..9761acddca 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -720,7 +720,9 @@ def _load_import(self, node: vy_ast.VyperNode, import_info: ImportInfo) -> Any: with override_global_namespace(Namespace()): module_t = _analyze_module_r(module_ast, is_interface=True) - # TODO: return the whole module + # NOTE: might be cleaner to return the whole module, so we + # have a ModuleInfo, that way we don't need to have different + # code paths for InterfaceT vs ModuleInfo return module_t.interface if path.suffix == ".json": From f94eebc61771e7244ffa682744ad8fdb67ef9a06 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 7 Oct 2024 10:53:17 -0400 Subject: [PATCH 13/14] add another comment --- vyper/compiler/phases.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 4f115bf4d0..d9b6b13b48 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -148,6 +148,7 @@ def vyper_module(self): @cached_property def _resolve_imports(self): + # deepcopy so as to not interfere with `-f ast` output vyper_module = copy.deepcopy(self.vyper_module) with self.input_bundle.search_path(Path(vyper_module.resolved_path).parent): return vyper_module, resolve_imports(vyper_module, self.input_bundle) From b5e417f51836be4f93957e3f4403357558fde2ec Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 7 Oct 2024 10:54:32 -0400 Subject: [PATCH 14/14] remove unused variable --- vyper/semantics/analysis/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 9761acddca..8a2beb61e6 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -701,13 +701,13 @@ def visit_ImportFrom(self, node): def _add_import(self, node: vy_ast.VyperNode) -> None: import_info = node._metadata["import_info"] # similar structure to import analyzer - module_info = self._load_import(node, import_info) + module_info = self._load_import(import_info) import_info._typ = module_info self.namespace[import_info.alias] = module_info - def _load_import(self, node: vy_ast.VyperNode, import_info: ImportInfo) -> Any: + def _load_import(self, import_info: ImportInfo) -> Any: path = import_info.compiler_input.path if path.suffix == ".vy": module_ast = import_info.parsed