diff --git a/pkg/trie/node/encode.go b/pkg/trie/node/encode.go index 1ad480b1e8..c0025d7330 100644 --- a/pkg/trie/node/encode.go +++ b/pkg/trie/node/encode.go @@ -16,16 +16,16 @@ import ( // of this package, and specified in the Polkadot spec at // https://spec.polkadot.network/#sect-state-storage func (n *Node) Encode(buffer Buffer) (err error) { + if n == nil { + _, err = buffer.Write([]byte{emptyVariant.bits}) + return err + } + err = encodeHeader(n, n.MustBeHashed, buffer) if err != nil { return fmt.Errorf("cannot encode header: %w", err) } - if n == nil { - // only encode the empty variant byte header - return nil - } - keyLE := codec.NibblesToKeyLE(n.PartialKey) _, err = buffer.Write(keyLE) if err != nil { @@ -55,7 +55,7 @@ func (n *Node) Encode(buffer Buffer) (err error) { _, err = buffer.Write(hashedValue.ToBytes()) if err != nil { - return fmt.Errorf("scale encoding storage value: %w", err) + return fmt.Errorf("writing hashed storage value: %w", err) } default: encoder := scale.NewEncoder(buffer) diff --git a/pkg/trie/node/encode_test.go b/pkg/trie/node/encode_test.go index fa2ea50867..496673b5bd 100644 --- a/pkg/trie/node/encode_test.go +++ b/pkg/trie/node/encode_test.go @@ -136,7 +136,7 @@ func Test_Node_Encode(t *testing.T) { PartialKey: []byte{1, 2, 3}, StorageValue: largeValue, IsHashedValue: true, - MustBeHashed: false, + MustBeHashed: true, }, writes: []writeCall{ { @@ -151,7 +151,7 @@ func Test_Node_Encode(t *testing.T) { }, }, wrappedErr: errTest, - errMessage: "encoding hashed storage value: test error", + errMessage: "writing hashed storage value: test error", }, "branch_header_encoding_error": { node: &Node{ diff --git a/pkg/trie/node/header.go b/pkg/trie/node/header.go index d450d2ded9..90d43195bf 100644 --- a/pkg/trie/node/header.go +++ b/pkg/trie/node/header.go @@ -11,11 +11,6 @@ import ( // encodeHeader writes the encoded header for the node. func encodeHeader(node *Node, isHashedValue bool, writer io.Writer) (err error) { - if node == nil { - _, err = writer.Write([]byte{emptyVariant.bits}) - return err - } - partialKeyLength := len(node.PartialKey) if partialKeyLength > int(maxPartialKeyLength) { panic(fmt.Sprintf("partial key length is too big: %d", partialKeyLength))