Skip to content

Commit

Permalink
feat[lang]: allow downcasting of bytestrings (#3832)
Browse files Browse the repository at this point in the history
This commit extends `convert()` to allow downcasting of Bytes/Strings,
e.g. converting `Bytes[20]` to `Bytes[19]`.

This improves the UX of bytestrings somewhat, since currently (prior to
this commit) there is no type-safe way to decrease the size of a
bytestring in Vyper. It also prepares us a little bit for adding generic
bytestrings inside the type system (`Bytes[...]`), which can only be
user-instantiated by `convert`ing to a known length.
  • Loading branch information
charles-cooper authored Mar 12, 2024
1 parent 246f4a7 commit 9cfe7b4
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 15 deletions.
61 changes: 57 additions & 4 deletions tests/functional/builtins/codegen/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import eth.codecs.abi.exceptions
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import InvalidLiteral, InvalidType, TypeMismatch
from vyper.semantics.types import AddressT, BoolT, BytesM_T, BytesT, DecimalT, IntegerT, StringT
from vyper.semantics.types.shortcuts import BYTES20_T, BYTES32_T, UINT, UINT160_T, UINT256_T
Expand Down Expand Up @@ -560,23 +561,75 @@ def foo(x: {i_typ}) -> {o_typ}:
assert_compile_failed(lambda: get_contract(code), TypeMismatch)


@pytest.mark.parametrize("typ", sorted(TEST_TYPES))
def test_bytes_too_large_cases(get_contract, assert_compile_failed, typ):
@pytest.mark.parametrize("typ", sorted(BASE_TYPES))
def test_bytes_too_large_cases(typ):
code_1 = f"""
@external
def foo(x: Bytes[33]) -> {typ}:
return convert(x, {typ})
"""
assert_compile_failed(lambda: get_contract(code_1), TypeMismatch)
with pytest.raises(TypeMismatch):
compile_code(code_1)

bytes_33 = b"1" * 33
code_2 = f"""
@external
def foo() -> {typ}:
return convert({bytes_33}, {typ})
"""
with pytest.raises(TypeMismatch):
compile_code(code_2)

assert_compile_failed(lambda: get_contract(code_2, TypeMismatch))

@pytest.mark.parametrize("cls1,cls2", itertools.product((StringT, BytesT), (StringT, BytesT)))
def test_bytestring_conversions(cls1, cls2, get_contract, tx_failed):
    """
    Check `convert()` from a 33-long bytestring type to a 32-long one.

    Runs for every pairing of `StringT`/`BytesT` as input (`cls1`) and
    output (`cls2`) class, covering both runtime arguments and literals.
    """
    typ1 = cls1(33)
    typ2 = cls2(32)

    def bytestring(cls, string):
        # python-side value matching `cls`: bytes for BytesT, str otherwise
        if cls == BytesT:
            return string.encode("utf-8")
        return string

    code_1 = f"""
@external
def foo(x: {typ1}) -> {typ2}:
    return convert(x, {typ2})
    """
    c = get_contract(code_1)

    # runtime values which fit in the output type round-trip unchanged
    for i in range(33):  # inclusive 32
        s = "1" * i
        arg = bytestring(cls1, s)
        out = bytestring(cls2, s)
        assert c.foo(arg) == out

    # a 33-long runtime value does not fit in the 32-long output type
    with tx_failed():
        # TODO: sanity check it is convert which is reverting, not arg clamping
        c.foo(bytestring(cls1, "1" * 33))

    code_2_template = """
@external
def foo() -> {typ}:
    return convert({arg}, {typ})
    """

    # test literals
    for i in range(33):  # inclusive 32
        s = "1" * i
        arg = bytestring(cls1, s)
        out = bytestring(cls2, s)
        code = code_2_template.format(typ=typ2, arg=repr(arg))
        if cls1 == cls2:  # ex.: can't convert "" to String[32]
            with pytest.raises(InvalidType):
                compile_code(code)
        else:
            c = get_contract(code)
            assert c.foo() == out

    # use repr() so the argument is emitted as a quoted literal, as in the
    # loop above; without it a `str` argument is interpolated as bare digits
    # (an integer literal) and the oversized-literal case is never exercised.
    failing_arg = bytestring(cls1, "1" * 33)
    failing_code = code_2_template.format(typ=typ2, arg=repr(failing_arg))
    with pytest.raises(TypeMismatch):
        compile_code(failing_code)


@pytest.mark.parametrize("n", range(1, 33))
Expand Down
30 changes: 19 additions & 11 deletions vyper/builtins/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,23 +422,31 @@ def to_address(expr, arg, out_typ):
return IRnode.from_list(ret, out_typ)


# question: should we allow bytesM -> String?
@_input_types(BytesT)
def to_string(expr, arg, out_typ):
_check_bytes(expr, arg, out_typ, out_typ.maxlen)
def _cast_bytestring(expr, arg, out_typ):
# ban converting Bytes[20] to Bytes[21]
if isinstance(arg.typ, out_typ.__class__) and arg.typ.maxlen <= out_typ.maxlen:
_FAIL(arg.typ, out_typ, expr)
# can't downcast literals with known length (e.g. b"abc" to Bytes[2])
if isinstance(expr, vy_ast.Constant) and arg.typ.maxlen > out_typ.maxlen:
_FAIL(arg.typ, out_typ, expr)

ret = ["seq"]
if out_typ.maxlen < arg.typ.maxlen:
ret.append(["assert", ["le", get_bytearray_length(arg), out_typ.maxlen]])
ret.append(arg)
# NOTE: this is a pointer cast
return IRnode.from_list(arg, typ=out_typ)
return IRnode.from_list(ret, typ=out_typ, location=arg.location, encoding=arg.encoding)


@_input_types(StringT)
def to_bytes(expr, arg, out_typ):
_check_bytes(expr, arg, out_typ, out_typ.maxlen)
# question: should we allow bytesM -> String?
@_input_types(BytesT, StringT)
def to_string(expr, arg, out_typ):
return _cast_bytestring(expr, arg, out_typ)

# TODO: more casts

# NOTE: this is a pointer cast
return IRnode.from_list(arg, typ=out_typ)
@_input_types(StringT, BytesT)
def to_bytes(expr, arg, out_typ):
return _cast_bytestring(expr, arg, out_typ)


@_input_types(IntegerT)
Expand Down

0 comments on commit 9cfe7b4

Please sign in to comment.