diff --git a/scripts/trie_state_script.go b/scripts/trie_state_script.go new file mode 100644 index 0000000000..0348134c54 --- /dev/null +++ b/scripts/trie_state_script.go @@ -0,0 +1,138 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/ChainSafe/gossamer/dot/rpc/modules" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/ChainSafe/gossamer/tests/utils/rpc" +) + +func fetchWithTimeout(ctx context.Context, + method, params string, target interface{}) { + + // Can adjust timeout as desired, default is very long + getResponseCtx, getResponseCancel := context.WithTimeout(ctx, 1000000*time.Second) + defer getResponseCancel() + err := getResponse(getResponseCtx, method, params, target) + if err != nil { + panic(fmt.Sprintf("error getting response %v", err)) + } +} + +func getResponse(ctx context.Context, method, params string, target interface{}) (err error) { + const rpcPort = "8545" + endpoint := rpc.NewEndpoint(rpcPort) + respBody, err := rpc.Post(ctx, endpoint, method, params) + if err != nil { + return fmt.Errorf("cannot RPC post: %w", err) + } + + err = rpc.Decode(respBody, &target) + if err != nil { + return fmt.Errorf("cannot decode RPC response: %w", err) + } + + return nil +} + +func writeTrieState(response modules.StateTrieResponse, destination string) { + encResponse, err := json.Marshal(response) + if err != nil { + panic(fmt.Sprintf("json marshalling response %v", err)) + } + + err = os.WriteFile(destination, encResponse, 0o600) + if err != nil { + panic(fmt.Sprintf("writing to file %v", err)) + } +} + +func fetchTrieState(ctx context.Context, blockHash common.Hash, destination string) modules.StateTrieResponse { + params := fmt.Sprintf(`["%s"]`, blockHash) + var response modules.StateTrieResponse + fetchWithTimeout(ctx, "state_trie", params, &response) + + writeTrieState(response, destination) + return response +} + +func compareStateRoots(response modules.StateTrieResponse, expectedStateRoot common.Hash, trieVersion trie.TrieLayout) { + entries := make(map[string]string, len(response)) + for _, encodedEntry := range response { + bytesEncodedEntry := common.MustHexToBytes(encodedEntry) + + entry := trie.Entry{} + err := scale.Unmarshal(bytesEncodedEntry, &entry) + if err != nil { + panic(fmt.Sprintf("error unmarshalling into trie entry %v", err)) + } + entries[common.BytesToHex(entry.Key)] = common.BytesToHex(entry.Value) + } + + newTrie, err := trie.LoadFromMap(entries) + if err != nil { + panic(fmt.Sprintf("loading trie from map %v", err)) + } + + trieHash := trieVersion.MustHash(newTrie) + if expectedStateRoot != trieHash { + panic("westendDevStateRoot does not match trieHash") + } +} + +/* +This is a script to query the trie state from a specific block height from a running node. + +Example commands to run a node: + + 1. ./bin/gossamer init --chain westend-dev --key alice + + 2. ./bin/gossamer --chain westend-dev --key alice --rpc-external=true --unsafe-rpc=true + +Once the node has started and processed the block whose state you need, can execute the script like so: + 1. go run trieStateScript.go +*/ +func main() { + if len(os.Args) < 3 { + panic("expected more arguments, block hash and destination file required") + } + + blockHash, err := common.HexToHash(os.Args[1]) + if err != nil { + panic("block hash must be in hex format") + } + + destinationFile := os.Args[2] + expectedStateRoot := common.Hash{} + var trieVersion trie.TrieLayout + if len(os.Args) == 5 { + expectedStateRoot, err = common.HexToHash(os.Args[3]) + if err != nil { + panic("expected state root must be in hex format") + } + + trieVersion, err = trie.ParseVersion(os.Args[4]) + if err != nil { + panic("trie version must be an integer") + } + } else if len(os.Args) != 3 { + panic("invalid number of arguments") + } + + ctx, _ := context.WithCancel(context.Background()) //nolint + response := fetchTrieState(ctx, blockHash, destinationFile) + + if !expectedStateRoot.IsEmpty() { + compareStateRoots(response, expectedStateRoot, trieVersion) + } +} diff --git a/scripts/trie_state_script_test.go b/scripts/trie_state_script_test.go new file mode 100644 index 0000000000..06bbedf654 --- /dev/null +++ b/scripts/trie_state_script_test.go @@ -0,0 +1,126 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package main + +import ( + "os" + "testing" + + "github.com/ChainSafe/gossamer/dot/rpc/modules" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" + "github.com/stretchr/testify/require" +) + +// This is fake data used just for testing purposes +var testStateData = []string{"0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000"} //nolint + +func clean(t *testing.T, file string) { + t.Helper() + err := os.Remove(file) + require.NoError(t, err) +} + +func Test_writeTrieState(t *testing.T) { + writeTrieState(testStateData, "westendDevTestState.json") + _, err := os.Stat("./westendDevTestState.json") + require.NoError(t, err) + + clean(t, "westendDevTestState.json") +} + +func Test_compareStateRoots(t *testing.T) { + type args struct { + response modules.StateTrieResponse + expectedStateRoot common.Hash + trieVersion trie.TrieLayout + } + tests := []struct { + name string + args args + shouldPanic bool + }{ + { + name: "happy_path", + args: args{ + response: testStateData, + expectedStateRoot: common.MustHexToHash("0x3b1863ff981a31864be76037e4cf5c927b937dd8a8e1e25494128da7a95b5cdf"), + trieVersion: 0, + }, + }, + { + name: "invalid_trie_version", + args: args{ + response: testStateData, + expectedStateRoot: common.MustHexToHash("0x6120d3afde6c139305bd7c0dcf50bdff5b620203e00c7491b2c30f95dccacc32"), + trieVersion: 21, + }, + shouldPanic: true, + }, + { + name: "hashes_do_not_match", + args: args{ + response: testStateData, + expectedStateRoot: common.MustHexToHash("0x01"), + trieVersion: 21, + }, + shouldPanic: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + require.Panics(t, + func() { + compareStateRoots(tt.args.response, tt.args.expectedStateRoot, tt.args.trieVersion) + }, + "The code did not panic") + } else { + compareStateRoots(tt.args.response, tt.args.expectedStateRoot, tt.args.trieVersion) + } + }) + } +} + +func Test_cli(t *testing.T) { + tests := []struct { + name string + args []string + }{ + { + name: "no_arguments", + }, + { + name: "to_few_arguments", + args: []string{"0x01"}, + }, + { + name: "invalid_formatting_for_block_hash", + args: []string{"hello", "output.json"}, + }, + { + name: "no_trie_version", + args: []string{"0x01", "output.json", "0x01"}, + }, + { + name: "invalid_formatting_for_root_hash", + args: []string{"0x01", "output.json", "hello", "1"}, + }, + { + name: "invalid_trie_version", + args: []string{"0x01", "output.json", "0x01", "hello"}, + }, + { + name: "to_many_arguments", + args: []string{"0x01", "output.json", "0x01", "1", "0x01"}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + os.Args = tt.args + require.Panics(t, func() { main() }, "The code did not panic") + }) + } +}