Skip to content

Commit

Permalink
fix[lang]: allow type expressions inside pure functions (#3906)
Browse files Browse the repository at this point in the history
20432c5 introduced a regression where type expressions like the
following would raise a compiler error instead of successfully
compiling:
```
@pure
def f():
    convert(..., uint256)  # raises `not a variable or literal: 'uint256'`
```

the reason is because `get_expr_info` is called on `uint256`, which
is not a regular expr. this commit introduces a fastpath return to
address the issue. longer-term, we should generalize the rules in
`vyper/semantics/analysis/local.py` so that AST traversal does not
progress into type expressions.
  • Loading branch information
charles-cooper authored Apr 3, 2024
1 parent 9db9078 commit 45a225c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
13 changes: 13 additions & 0 deletions tests/functional/codegen/features/decorators/test_pure.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,19 @@ def foo() -> uint256:
compile_code(code)


def test_type_in_pure(get_contract):
code = """
@pure
@external
def _convert(x: bytes32) -> uint256:
return convert(x, uint256)
"""
c = get_contract(code)
x = 123456
bs = x.to_bytes(32, "big")
assert x == c._convert(bs)


def test_invalid_conflicting_decorators():
code = """
@pure
Expand Down
9 changes: 6 additions & 3 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ def _validate_msg_value_access(node: vy_ast.Attribute) -> None:
raise NonPayableViolation("msg.value is not allowed in non-payable functions", node)


def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name) -> None:
def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name, typ: VyperType) -> None:
if isinstance(typ, TYPE_T):
return

info = get_expr_info(node)

env_vars = CONSTANT_ENVIRONMENT_VARS
Expand Down Expand Up @@ -705,7 +708,7 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None:
_validate_msg_value_access(node)

if self.func and self.func.mutability == StateMutability.PURE:
_validate_pure_access(node)
_validate_pure_access(node, typ)

value_type = get_exact_type_from_node(node.value)

Expand Down Expand Up @@ -886,7 +889,7 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None:

def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None:
if self.func and self.func.mutability == StateMutability.PURE:
_validate_pure_access(node)
_validate_pure_access(node, typ)

def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None:
if isinstance(typ, TYPE_T):
Expand Down

0 comments on commit 45a225c

Please # to comment.