From 5edbf89541908cd73493d8c08de32cd48945aeb3 Mon Sep 17 00:00:00 2001 From: Kishan Sagathiya Date: Mon, 12 Feb 2024 19:43:43 +0530 Subject: [PATCH] fix: add a limit of number of bytes while scale decoding a slice (#3733) While scale decoding we first read the length of bytes to decode and then we decode that many bytes. Someone could ask us to decode some malicious bytes such that the read length is unreasonably big. In such case, we would have to create a byte slice as big as the length. The length in byte slice is an encoded as `Compact`. Current we are just reading length as uint and not checking if it goes beyond the bounds of uint32. So, we would either panic because of `makeslice: len out of range` or because the asked length would be more than the memory we have available in our machine. We are going to put a check to makes sure that this length is less than max of uint32. --- dot/types/block_test.go | 11 +++++++++++ internal/trie/node/decode_test.go | 8 ++++---- lib/runtime/version_test.go | 2 +- pkg/scale/decode.go | 18 +++++++++++++++++- 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/dot/types/block_test.go b/dot/types/block_test.go index aa2e4173eb..e73a4b505e 100644 --- a/dot/types/block_test.go +++ b/dot/types/block_test.go @@ -131,3 +131,14 @@ func TestMustEncodeBlock(t *testing.T) { }) } } + +func TestScaleUnmarshal(t *testing.T) { + block := NewBlock(*NewEmptyHeader(), Body{}) + err := scale.Unmarshal( + []byte{48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 4, 48, 48, 48, 48, 19, 48, 48, 48, 48, 48, 48, 48, 48}, //nolint + &block, + ) + + require.EqualError(t, err, + "decoding struct: unmarshalling field at index 0: decoding struct: unmarshalling field at index 4: decoding struct: unmarshalling field at index 1: byte array length 3472328296227680304 exceeds max value of uint32") //nolint +} diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index 49066d641e..5d12bff3e0 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -211,7 +211,7 @@ func Test_decodeBranch(t *testing.T) { nodeVariant: branchVariant, partialKeyLength: 1, errWrapped: ErrDecodeChildHash, - errMessage: "cannot decode child hash: at index 10: reading byte: EOF", + errMessage: "cannot decode child hash: at index 10: decoding uint: reading byte: EOF", }, "success_for_branch_variant": { reader: bytes.NewBuffer( @@ -246,7 +246,7 @@ func Test_decodeBranch(t *testing.T) { nodeVariant: branchWithValueVariant, partialKeyLength: 1, errWrapped: ErrDecodeStorageValue, - errMessage: "cannot decode storage value: reading byte: EOF", + errMessage: "cannot decode storage value: decoding uint: reading byte: EOF", }, "success_for_branch_with_value": { reader: bytes.NewBuffer(concatByteSlices([][]byte{ @@ -372,7 +372,7 @@ func Test_decodeLeaf(t *testing.T) { variant: leafVariant, partialKeyLength: 1, errWrapped: ErrDecodeStorageValue, - errMessage: "cannot decode storage value: unknown prefix for compact uint: 255", + errMessage: "cannot decode storage value: decoding uint: unknown prefix for compact uint: 255", }, "missing_storage_value_data": { reader: bytes.NewBuffer([]byte{ @@ -382,7 +382,7 @@ func Test_decodeLeaf(t *testing.T) { variant: leafVariant, partialKeyLength: 1, errWrapped: ErrDecodeStorageValue, - errMessage: "cannot decode storage value: reading byte: EOF", + errMessage: "cannot decode storage value: decoding uint: reading byte: EOF", }, "empty_storage_value_data": { reader: bytes.NewBuffer(concatByteSlices([][]byte{ diff --git a/lib/runtime/version_test.go b/lib/runtime/version_test.go index 7525c1e08e..d553731795 100644 --- a/lib/runtime/version_test.go +++ b/lib/runtime/version_test.go @@ -39,7 +39,7 @@ func Test_DecodeVersion(t *testing.T) { {255, 255}, // error }), errWrapped: ErrDecodingVersionField, - errMessage: "decoding version field impl name: unknown prefix for compact uint: 255", + errMessage: "decoding version field impl name: decoding uint: unknown prefix for compact uint: 255", }, // TODO add transaction version decode error once // https://github.com/ChainSafe/gossamer/pull/2683 diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 15b8af8eb6..af52cf5cd7 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "math" "math/big" "reflect" ) @@ -475,6 +476,7 @@ func (ds *decodeState) decodeBool(dstv reflect.Value) (err error) { return } +// TODO: Should this be renamed to decodeCompactInt? // decodeUint will decode unsigned integer func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) { const maxUint32 = ^uint32(0) @@ -491,8 +493,12 @@ func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) { var value uint64 switch mode { case 0: + // 0b00: single-byte mode; upper six bits are the LE encoding of the value (valid only for + // values of 0-63). value = uint64(prefix >> 2) case 1: + // 0b01: two-byte mode: upper six bits and the following byte is the LE encoding of the + // value (valid only for values 64-(2**14-1)) buf, err := ds.ReadByte() if err != nil { return fmt.Errorf("reading byte: %w", err) @@ -502,6 +508,8 @@ func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) { return fmt.Errorf("%w: %d (%b)", ErrU16OutOfRange, value, value) } case 2: + // 0b10: four-byte mode: upper six bits and the following three bytes are the LE encoding + // of the value (valid only for values (2**14)-(2**30-1)). buf := make([]byte, 3) _, err = ds.Read(buf) if err != nil { @@ -512,6 +520,9 @@ func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) { return fmt.Errorf("%w: %d (%b)", ErrU32OutOfRange, value, value) } case 3: + // 0b11: Big-integer mode: The upper six bits are the number of bytes following, plus four. + // The value is contained, LE encoded, in the bytes following. The final (most significant) + // byte must be non-zero. Valid only for values (2**30)-(2**536-1). byteLen := (prefix >> 2) + 4 buf := make([]byte, byteLen) _, err = ds.Read(buf) @@ -557,7 +568,7 @@ func (ds *decodeState) decodeLength() (l uint, err error) { dstv := reflect.New(reflect.TypeOf(l)) err = ds.decodeUint(dstv.Elem()) if err != nil { - return + return 0, fmt.Errorf("decoding uint: %w", err) } l = dstv.Elem().Interface().(uint) return @@ -570,6 +581,11 @@ func (ds *decodeState) decodeBytes(dstv reflect.Value) (err error) { return } + // bytes length in encoded as Compact, so it can't be more than math.MaxUint32 + if length > math.MaxUint32 { + return fmt.Errorf("byte array length %d exceeds max value of uint32", length) + } + b := make([]byte, length) if length > 0 {