diff --git a/.gitignore b/.gitignore index aa1e6f0a..8536fd50 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,6 @@ rpc*/.env.testnet tmp/ -examples/**/*.json examples/**/*.sum */**/*abi.json diff --git a/README.md b/README.md index 4af8e85a..e9f8b1cc 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ operations on the wallets. The package has excellent documentation for a smooth - [deploy account example](./examples/deployAccount) to deploy a new account contract on testnet. - [invoke transaction example](./examples/simpleInvoke) to add a new invoke transaction on testnet. - [deploy contract UDC example](./examples/deployContractUDC) to deploy an ERC20 token using [UDC (Universal Deployer Contract)](https://docs.starknet.io/architecture-and-concepts/accounts/universal-deployer/) on testnet. +- [typed data example](./examples/typedData) to sign and verify a typed data. ### Run Examples diff --git a/curve/curve.go b/curve/curve.go index a188fbfb..d1d93dc2 100644 --- a/curve/curve.go +++ b/curve/curve.go @@ -568,6 +568,19 @@ func Pedersen(a, b *felt.Felt) *felt.Felt { return junoCrypto.Pedersen(a, b) } +// Poseidon is a function that implements the Poseidon hash. +// NOTE: This function just wraps the Juno implementation +// (ref: https://github.com/NethermindEth/juno/blob/32fd743c774ec11a1bb2ce3dceecb57515f4873e/core/crypto/poseidon_hash.go#L59) +// +// Parameters: +// - a: a pointers to felt.Felt to be hashed. +// - b: a pointers to felt.Felt to be hashed. +// Returns: +// - *felt.Felt: a pointer to a felt.Felt storing the resulting hash. +func Poseidon(a, b *felt.Felt) *felt.Felt { + return junoCrypto.Poseidon(a, b) +} + // PedersenArray is a function that takes a variadic number of felt.Felt pointers as parameters and // calls the PedersenArray function from the junoCrypto package with the provided parameters. // NOTE: This function just wraps the Juno implementation @@ -590,7 +603,7 @@ func PedersenArray(felts ...*felt.Felt) *felt.Felt { // - felts: A variadic number of pointers to felt.Felt // Returns: // - *felt.Felt: pointer to a felt.Felt -func (sc StarkCurve) PoseidonArray(felts ...*felt.Felt) *felt.Felt { +func PoseidonArray(felts ...*felt.Felt) *felt.Felt { return junoCrypto.PoseidonArray(felts...) } @@ -603,7 +616,7 @@ func (sc StarkCurve) PoseidonArray(felts ...*felt.Felt) *felt.Felt { // Returns: // - *felt.Felt: pointer to a felt.Felt // - error: An error if any -func (sc StarkCurve) StarknetKeccak(b []byte) *felt.Felt { +func StarknetKeccak(b []byte) *felt.Felt { return junoCrypto.StarknetKeccak(b) } @@ -709,3 +722,48 @@ func (sc StarkCurve) PrivateToPoint(privKey *big.Int) (x, y *big.Int, err error) x, y = sc.EcMult(privKey, sc.EcGenX, sc.EcGenY) return x, y, nil } + +// VerifySignature verifies the ECDSA signature of a given message hash using the provided public key. +// +// It takes the message hash, the r and s values of the signature, and the public key as strings and +// verifies the signature using the public key. +// +// Parameters: +// - msgHash: The hash of the message to be verified as a string +// - r: The r value (the first part) of the signature as a string +// - s: The s value (the second part) of the signature as a string +// - pubKey: The public key (only the x coordinate) as a string +// Return values: +// - bool: A boolean indicating whether the signature is valid +// - error: An error if any occurred during the verification process +func VerifySignature(msgHash, r, s, pubKey string) bool { + feltMsgHash, err := new(felt.Felt).SetString(msgHash) + if err != nil { + return false + } + feltR, err := new(felt.Felt).SetString(r) + if err != nil { + return false + } + feltS, err := new(felt.Felt).SetString(s) + if err != nil { + return false + } + pubKeyFelt, err := new(felt.Felt).SetString(pubKey) + if err != nil { + return false + } + + signature := junoCrypto.Signature{ + R: *feltR, + S: *feltS, + } + + pubKeyStruct := junoCrypto.NewPublicKey(pubKeyFelt) + resp, err := pubKeyStruct.Verify(&signature, feltMsgHash) + if err != nil { + return false + } + + return resp +} diff --git a/curve/curve_test.go b/curve/curve_test.go index 711c2475..53d1cd5d 100644 --- a/curve/curve_test.go +++ b/curve/curve_test.go @@ -501,3 +501,28 @@ func TestGeneral_SplitFactStr(t *testing.T) { require.Equal(t, d["h"], h) } } + +// TestGeneral_VerifySignature is a test function that verifies the correctness of the VerifySignature function. +// +// It checks if the signature of a given message hash is valid using the provided r, s values and the public key. +// The function takes no parameters and returns no values. +// +// Parameters: +// - t: The testing.T object for running the test +// Returns: +// +// none +func TestGeneral_VerifySignature(t *testing.T) { + // values verified with starknet.js + + msgHash := "0x2789daed76c8b750d5a609a706481034db9dc8b63ae01f505d21e75a8fc2336" + r := "0x13e4e383af407f7ccc1f13195ff31a58cad97bbc6cf1d532798b8af616999d4" + s := "0x44dd06cf67b2ba7ea4af346d80b0b439e02a0b5893c6e4dfda9ee204211c879" + fullPubKey := "0x6c7c4408e178b2999cef9a5b3fa2a3dffc876892ad6a6bd19d1451a2256906c" + + require.True(t, VerifySignature(msgHash, r, s, fullPubKey)) + + // Change the last digit of the message hash to test invalid signature + wrongMsgHash := "0x2789daed76c8b750d5a609a706481034db9dc8b63ae01f505d21e75a8fc2337" + require.False(t, VerifySignature(wrongMsgHash, r, s, fullPubKey)) +} diff --git a/examples/README.md b/examples/README.md index b1618434..d2035384 100644 --- a/examples/README.md +++ b/examples/README.md @@ -36,4 +36,6 @@ To run an example: R: See [simpleCall](./simpleCall/main.go). 1. How to make a function call? R: See [simpleCall](./simpleCall/main.go). +1. How to sign and verify a typed data? + R: See [typedData](./typedData/main.go). diff --git a/examples/typedData/README.md b/examples/typedData/README.md new file mode 100644 index 00000000..82d785e9 --- /dev/null +++ b/examples/typedData/README.md @@ -0,0 +1,12 @@ +This example shows how to sign and verify a typed data. + +Steps: +1. Rename the ".env.template" file located at the root of the "examples" folder to ".env" +1. Uncomment, and assign your Sepolia testnet endpoint to the `RPC_PROVIDER_URL` variable in the ".env" file +1. Uncomment, and assign your account address to the `ACCOUNT_ADDRESS` variable in the ".env" file (make sure to have a few ETH in it) +1. Uncomment, and assign your starknet public key to the `PUBLIC_KEY` variable in the ".env" file +1. Uncomment, and assign your private key to the `PRIVATE_KEY` variable in the ".env" file +1. Make sure you are in the "typedData" directory +1. Execute `go run main.go` + +The message hash, signature and the verification result will be printed at the end of the execution. diff --git a/examples/typedData/baseExample.json b/examples/typedData/baseExample.json new file mode 100644 index 00000000..0780ecd9 --- /dev/null +++ b/examples/typedData/baseExample.json @@ -0,0 +1,35 @@ +{ + "types": { + "StarkNetDomain": [ + { "name": "name", "type": "felt" }, + { "name": "version", "type": "felt" }, + { "name": "chainId", "type": "felt" } + ], + "Person": [ + { "name": "name", "type": "felt" }, + { "name": "wallet", "type": "felt" } + ], + "Mail": [ + { "name": "from", "type": "Person" }, + { "name": "to", "type": "Person" }, + { "name": "contents", "type": "felt" } + ] + }, + "primaryType": "Mail", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": 1 + }, + "message": { + "from": { + "name": "Cow", + "wallet": "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826" + }, + "to": { + "name": "Bob", + "wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB" + }, + "contents": "Hello, Bob!" + } +} \ No newline at end of file diff --git a/examples/typedData/main.go b/examples/typedData/main.go new file mode 100644 index 00000000..efc11fdd --- /dev/null +++ b/examples/typedData/main.go @@ -0,0 +1,91 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "math/big" + "os" + + "github.com/NethermindEth/starknet.go/account" + "github.com/NethermindEth/starknet.go/curve" + "github.com/NethermindEth/starknet.go/rpc" + "github.com/NethermindEth/starknet.go/typedData" + "github.com/NethermindEth/starknet.go/utils" + + setup "github.com/NethermindEth/starknet.go/examples/internal" +) + +// NOTE : Please add in your keys only for testing purposes, in case of a leak you would potentially lose your funds. + +func main() { + // Setup the account + accnt := localSetup() + fmt.Println("Account address:", accnt.AccountAddress) + + // This is how you can initialize a typed data from a JSON file + var ttd typedData.TypedData + content, err := os.ReadFile("./baseExample.json") + if err != nil { + panic(fmt.Errorf("fail to read file: %w", err)) + } + err = json.Unmarshal(content, &ttd) + if err != nil { + panic(fmt.Errorf("fail to unmarshal TypedData: %w", err)) + } + + // This is how you can get the message hash linked to your account address + messageHash, err := ttd.GetMessageHash(accnt.AccountAddress.String()) + if err != nil { + panic(fmt.Errorf("fail to get message hash: %w", err)) + } + fmt.Println("Message hash:", messageHash) + + // This is how you can sign the message hash + signature, err := accnt.Sign(context.Background(), messageHash) + if err != nil { + panic(fmt.Errorf("fail to sign message: %w", err)) + } + fmt.Println("Signature:", signature) + + // This is how you can verify the signature + isValid := curve.VerifySignature(messageHash.String(), signature[0].String(), signature[1].String(), setup.GetPublicKey()) + fmt.Println("Verification result:", isValid) +} + +func localSetup() *account.Account { + // Load variables from '.env' file + rpcProviderUrl := setup.GetRpcProviderUrl() + accountAddress := setup.GetAccountAddress() + accountCairoVersion := setup.GetAccountCairoVersion() + privateKey := setup.GetPrivateKey() + publicKey := setup.GetPublicKey() + + // Initialize connection to RPC provider + client, err := rpc.NewProvider(rpcProviderUrl) + if err != nil { + panic(fmt.Sprintf("Error dialing the RPC provider: %s", err)) + } + + // Initialize the account memkeyStore (set public and private keys) + ks := account.NewMemKeystore() + privKeyBI, ok := new(big.Int).SetString(privateKey, 0) + if !ok { + panic("Fail to convert privKey to bitInt") + } + ks.Put(publicKey, privKeyBI) + + // Here we are converting the account address to felt + accountAddressInFelt, err := utils.HexToFelt(accountAddress) + if err != nil { + fmt.Println("Failed to transform the account address, did you give the hex address?") + panic(err) + } + // Initialize the account + accnt, err := account.NewAccount(client, accountAddressInFelt, publicKey, ks, accountCairoVersion) + if err != nil { + panic(err) + } + + return accnt +} diff --git a/hash/hash.go b/hash/hash.go index 71d53cdc..846a5f31 100644 --- a/hash/hash.go +++ b/hash/hash.go @@ -65,11 +65,11 @@ func ClassHash(contract rpc.ContractClass) *felt.Felt { ConstructorHash := hashEntryPointByType(contract.EntryPointsByType.Constructor) ExternalHash := hashEntryPointByType(contract.EntryPointsByType.External) L1HandleHash := hashEntryPointByType(contract.EntryPointsByType.L1Handler) - SierraProgamHash := curve.Curve.PoseidonArray(contract.SierraProgram...) - ABIHash := curve.Curve.StarknetKeccak([]byte(contract.ABI)) + SierraProgamHash := curve.PoseidonArray(contract.SierraProgram...) + ABIHash := curve.StarknetKeccak([]byte(contract.ABI)) // https://docs.starknet.io/documentation/architecture_and_concepts/Network_Architecture/transactions/#deploy_account_hash_calculation - return curve.Curve.PoseidonArray(ContractClassVersionHash, ExternalHash, L1HandleHash, ConstructorHash, ABIHash, SierraProgamHash) + return curve.PoseidonArray(ContractClassVersionHash, ExternalHash, L1HandleHash, ConstructorHash, ABIHash, SierraProgamHash) } // hashEntryPointByType calculates the hash of an entry point by type. @@ -83,7 +83,7 @@ func hashEntryPointByType(entryPoint []rpc.SierraEntryPoint) *felt.Felt { for _, elt := range entryPoint { flattened = append(flattened, elt.Selector, new(felt.Felt).SetUint64(uint64(elt.FunctionIdx))) } - return curve.Curve.PoseidonArray(flattened...) + return curve.PoseidonArray(flattened...) } // CompiledClassHash calculates the hash of a compiled class in the Casm format. @@ -97,10 +97,10 @@ func CompiledClassHash(casmClass contracts.CasmClass) *felt.Felt { ExternalHash := hashCasmClassEntryPointByType(casmClass.EntryPointByType.External) L1HandleHash := hashCasmClassEntryPointByType(casmClass.EntryPointByType.L1Handler) ConstructorHash := hashCasmClassEntryPointByType(casmClass.EntryPointByType.Constructor) - ByteCodeHasH := curve.Curve.PoseidonArray(casmClass.ByteCode...) + ByteCodeHasH := curve.PoseidonArray(casmClass.ByteCode...) // https://github.com/software-mansion/starknet.py/blob/development/starknet_py/hash/casm_class_hash.py#L10 - return curve.Curve.PoseidonArray(ContractClassVersionHash, ExternalHash, L1HandleHash, ConstructorHash, ByteCodeHasH) + return curve.PoseidonArray(ContractClassVersionHash, ExternalHash, L1HandleHash, ConstructorHash, ByteCodeHasH) } // hashCasmClassEntryPointByType calculates the hash of a CasmClassEntryPoint array. @@ -116,8 +116,8 @@ func hashCasmClassEntryPointByType(entryPoint []contracts.CasmClassEntryPoint) * for _, builtIn := range elt.Builtins { builtInFlat = append(builtInFlat, new(felt.Felt).SetBytes([]byte(builtIn))) } - builtInHash := curve.Curve.PoseidonArray(builtInFlat...) + builtInHash := curve.PoseidonArray(builtInFlat...) flattened = append(flattened, elt.Selector, new(felt.Felt).SetUint64(uint64(elt.Offset)), builtInHash) } - return curve.Curve.PoseidonArray(flattened...) + return curve.PoseidonArray(flattened...) } diff --git a/typed/typed.go b/typed/typed.go deleted file mode 100644 index f1b89915..00000000 --- a/typed/typed.go +++ /dev/null @@ -1,239 +0,0 @@ -package typed - -import ( - "bytes" - "encoding/hex" - "fmt" - "math/big" - "regexp" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/starknet.go/curve" - "github.com/NethermindEth/starknet.go/utils" -) - -type TypedData struct { - Types map[string]TypeDef - PrimaryType string - Domain Domain - Message TypedMessage -} - -type Domain struct { - Name string - Version string - ChainId string -} - -type TypeDef struct { - Encoding *big.Int - Definitions []Definition -} - -type Definition struct { - Name string - Type string -} - -type TypedMessage interface { - FmtDefinitionEncoding(string) []*big.Int -} - -// FmtDefinitionEncoding formats the definition (standard Starknet Domain) encoding. -// -// Parameters: -// - field: the field to format the encoding for -// Returns: -// - fmtEnc: a slice of big integers -func (dm Domain) FmtDefinitionEncoding(field string) (fmtEnc []*big.Int) { - processStrToBig := func(fieldVal string) { - feltVal := strToFelt(fieldVal) - bigInt := utils.FeltToBigInt(feltVal) - fmtEnc = append(fmtEnc, bigInt) - } - - switch field { - case "name": - processStrToBig(dm.Name) - case "version": - processStrToBig(dm.Version) - case "chainId": - processStrToBig(dm.ChainId) - } - return fmtEnc -} - -// strToFelt converts a string (decimal, hexadecimal or UTF8 charset) to a *felt.Felt. -// -// Parameters: -// - str: the string to convert to a *felt.Felt -// Returns: -// - *felt.Felt: a *felt.Felt with the value of str -func strToFelt(str string) *felt.Felt { - var f = new(felt.Felt) - asciiRegexp := regexp.MustCompile(`^([[:graph:]]|[[:space:]]){1,31}$`) - - if b, ok := new(big.Int).SetString(str, 0); ok { - f.SetBytes(b.Bytes()) - return f - } - // TODO: revisit conversation on seperate 'ShortString' conversion - if asciiRegexp.MatchString(str) { - hexStr := hex.EncodeToString([]byte(str)) - if b, ok := new(big.Int).SetString(hexStr, 16); ok { - f.SetBytes(b.Bytes()) - return f - } - } - - return f -} - -// NewTypedData initializes a new TypedData object with the given types, primary type, and domain -// for interacting and signing in accordance with https://github.com/0xs34n/starknet.js/tree/develop/src/utils/typedData -// If the primary type is invalid, it returns an error with the message "invalid primary type: {pType}". -// If there is an error encoding the type hash, it returns an error with the message "error encoding type hash: {enc.String()} {err}". -// -// Parameters: -// - types: a map[string]TypeDef representing the types associated with their names. -// - pType: a string representing the primary type. -// - dom: a Domain representing the domain. -// Returns: -// - td: a TypedData object -// - err: an error if any -func NewTypedData(types map[string]TypeDef, pType string, dom Domain) (td TypedData, err error) { - td = TypedData{ - Types: types, - PrimaryType: pType, - Domain: dom, - } - if _, ok := td.Types[pType]; !ok { - return td, fmt.Errorf("invalid primary type: %s", pType) - } - - for k, v := range td.Types { - enc, err := td.GetTypeHash(k) - if err != nil { - return td, fmt.Errorf("error encoding type hash: %s %w", enc.String(), err) - } - v.Encoding = enc - td.Types[k] = v - } - return td, nil -} - -// GetMessageHash calculates the hash of a typed message for a given account using the StarkCurve. -// (ref: https://github.com/0xs34n/starknet.js/blob/767021a203ac0b9cdb282eb6d63b33bfd7614858/src/utils/typedData/index.ts#L166) -// -// Parameters: -// - account: A pointer to a big.Int representing the account. -// - msg: A TypedMessage object representing the message. -// Returns: -// - hash: A pointer to a big.Int representing the calculated hash. -func (td TypedData) GetMessageHash(account *big.Int, msg TypedMessage) (hash *big.Int) { - elements := []*big.Int{utils.UTF8StrToBig("StarkNet Message")} - - domEnc := td.GetTypedMessageHash("StarkNetDomain", td.Domain) - - elements = append(elements, domEnc) - elements = append(elements, account) - - msgEnc := td.GetTypedMessageHash(td.PrimaryType, msg) - - elements = append(elements, msgEnc) - - return curve.ComputeHashOnElements(elements) -} - -// GetTypedMessageHash calculates the hash of a typed message using the provided StarkCurve. -// -// Parameters: -// - inType: the type of the message -// - msg: the typed message -// -// Returns: -// - hash: the calculated hash -func (td TypedData) GetTypedMessageHash(inType string, msg TypedMessage) (hash *big.Int) { - prim := td.Types[inType] - elements := []*big.Int{prim.Encoding} - - for _, def := range prim.Definitions { - if def.Type == "felt" { - fmtDefinitions := msg.FmtDefinitionEncoding(def.Name) - elements = append(elements, fmtDefinitions...) - continue - } - - innerElements := []*big.Int{} - encType := td.Types[def.Type] - innerElements = append(innerElements, encType.Encoding) - fmtDefinitions := msg.FmtDefinitionEncoding(def.Name) - innerElements = append(innerElements, fmtDefinitions...) - innerElements = append(innerElements, big.NewInt(int64(len(innerElements)))) - - innerHash := curve.HashPedersenElements(innerElements) - elements = append(elements, innerHash) - } - - return curve.ComputeHashOnElements(elements) -} - -// GetTypeHash returns the hash of the given type. -// -// Parameters: -// - inType: the type to hash -// Returns: -// - ret: the hash of the given type -// - err: any error if any -func (td TypedData) GetTypeHash(inType string) (ret *big.Int, err error) { - enc, err := td.EncodeType(inType) - if err != nil { - return ret, err - } - return utils.GetSelectorFromName(enc), nil -} - -// EncodeType encodes the given inType using the TypedData struct. -// -// Parameters: -// - inType: the type to encode -// Returns: -// - enc: the encoded type -// - err: any error if any -func (td TypedData) EncodeType(inType string) (enc string, err error) { - var typeDefs TypeDef - var ok bool - if typeDefs, ok = td.Types[inType]; !ok { - return enc, fmt.Errorf("can't parse type %s from types %v", inType, td.Types) - } - var buf bytes.Buffer - customTypes := make(map[string]TypeDef) - buf.WriteString(inType) - buf.WriteString("(") - for i, def := range typeDefs.Definitions { - if def.Type != "felt" { - var customTypeDef TypeDef - if customTypeDef, ok = td.Types[def.Type]; !ok { - return enc, fmt.Errorf("can't parse type %s from types %v", def.Type, td.Types) - } - customTypes[def.Type] = customTypeDef - } - buf.WriteString(fmt.Sprintf("%s:%s", def.Name, def.Type)) - if i != (len(typeDefs.Definitions) - 1) { - buf.WriteString(",") - } - } - buf.WriteString(")") - - for customTypeName, customType := range customTypes { - buf.WriteString(fmt.Sprintf("%s(", customTypeName)) - for i, def := range customType.Definitions { - buf.WriteString(fmt.Sprintf("%s:%s", def.Name, def.Type)) - if i != (len(customType.Definitions) - 1) { - buf.WriteString(",") - } - } - buf.WriteString(")") - } - return buf.String(), nil -} diff --git a/typed/typed_test.go b/typed/typed_test.go deleted file mode 100644 index f58d87bc..00000000 --- a/typed/typed_test.go +++ /dev/null @@ -1,280 +0,0 @@ -package typed - -import ( - "fmt" - "math/big" - "testing" - - "github.com/NethermindEth/starknet.go/utils" - "github.com/stretchr/testify/require" -) - -type Mail struct { - From Person - To Person - Contents string -} - -type Person struct { - Name string - Wallet string -} - -// FmtDefinitionEncoding formats the encoding for the given field in the Mail struct. -// -// Parameters: -// - field: the field to format the encoding for -// Returns: -// - fmtEnc: a slice of big integers -func (mail Mail) FmtDefinitionEncoding(field string) (fmtEnc []*big.Int) { - if field == "from" { - fmtEnc = append(fmtEnc, utils.UTF8StrToBig(mail.From.Name)) - fmtEnc = append(fmtEnc, utils.HexToBN(mail.From.Wallet)) - } else if field == "to" { - fmtEnc = append(fmtEnc, utils.UTF8StrToBig(mail.To.Name)) - fmtEnc = append(fmtEnc, utils.HexToBN(mail.To.Wallet)) - } else if field == "contents" { - fmtEnc = append(fmtEnc, utils.UTF8StrToBig(mail.Contents)) - } - return fmtEnc -} - -// MockTypedData generates a TypedData object for testing purposes. -// It creates example types and initializes a Domain object. Then it uses the example types and the domain to create a new TypedData object. -// The function returns the generated TypedData object. -// -// Parameters: -// -// none -// -// Returns: -// - ttd: the generated TypedData object -func MockTypedData() (ttd TypedData, err error) { - exampleTypes := make(map[string]TypeDef) - domDefs := []Definition{{"name", "felt"}, {"version", "felt"}, {"chainId", "felt"}} - exampleTypes["StarkNetDomain"] = TypeDef{Definitions: domDefs} - mailDefs := []Definition{{"from", "Person"}, {"to", "Person"}, {"contents", "felt"}} - exampleTypes["Mail"] = TypeDef{Definitions: mailDefs} - persDefs := []Definition{{"name", "felt"}, {"wallet", "felt"}} - exampleTypes["Person"] = TypeDef{Definitions: persDefs} - - dm := Domain{ - Name: "StarkNet Mail", - Version: "1", - ChainId: "1", - } - - ttd, err = NewTypedData(exampleTypes, "Mail", dm) - if err != nil { - return TypedData{}, err - } - return ttd, err -} - -// TestGeneral_GetMessageHash tests the GetMessageHash function. -// -// It creates a mock TypedData and sets up a test case for hashing a mail message. -// The mail message contains information about the sender and recipient, as well as the contents of the message. -// The function then calls the GetMessageHash function with the necessary parameters to calculate the message hash. -// If an error occurs during the hashing process, an error is reported using the t.Errorf function. -// The expected hash value is compared with the actual hash value returned by the function. -// If the values do not match, an error is reported using the t.Errorf function. -// -// Parameters: -// - t: a testing.T object that provides methods for testing functions -// Returns: -// - None -func TestGeneral_GetMessageHash(t *testing.T) { - ttd, err := MockTypedData() - require.NoError(t, err) - - mail := Mail{ - From: Person{ - Name: "Cow", - Wallet: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", - }, - To: Person{ - Name: "Bob", - Wallet: "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", - }, - Contents: "Hello, Bob!", - } - - hash := ttd.GetMessageHash(utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), mail) - - exp := "0x6fcff244f63e38b9d88b9e3378d44757710d1b244282b435cb472053c8d78d0" - require.Equal(t, exp, utils.BigToHex(hash)) -} - -// BenchmarkGetMessageHash is a benchmark function for testing the GetMessageHash function. -// -// It tests the performance of the GetMessageHash function by running it with different input sizes. -// The input size is determined by the bit length of the address parameter, which is converted from -// a hexadecimal string to a big integer using the HexToBN function from the utils package. -// -// Parameters: -// - b: a testing.B object that provides methods for benchmarking the function -// Returns: -// -// none -func BenchmarkGetMessageHash(b *testing.B) { - ttd, err := MockTypedData() - require.NoError(b, err) - - mail := Mail{ - From: Person{ - Name: "Cow", - Wallet: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", - }, - To: Person{ - Name: "Bob", - Wallet: "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", - }, - Contents: "Hello, Bob!", - } - addr := utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826") - b.Run(fmt.Sprintf("input_size_%d", addr.BitLen()), func(b *testing.B) { - result := ttd.GetMessageHash(addr, mail) - require.NotEmpty(b, result) - }) -} - -// TestGeneral_GetDomainHash tests the GetDomainHash function. -// It creates a mock TypedData object and generates the hash of a typed message using the Starknet domain and curve. -// If there is an error during the hashing process, it logs the error. -// It then compares the generated hash with the expected hash and logs an error if they do not match. -// -// Parameters: -// - t: a testing.T object that provides methods for testing functions -// Returns: -// -// none -func TestGeneral_GetDomainHash(t *testing.T) { - ttd, err := MockTypedData() - require.NoError(t, err) - - hash := ttd.GetTypedMessageHash("StarkNetDomain", ttd.Domain) - - exp := "0x54833b121883a3e3aebff48ec08a962f5742e5f7b973469c1f8f4f55d470b07" - require.Equal(t, exp, utils.BigToHex(hash)) -} - -// TestGeneral_GetTypedMessageHash is a unit test for the GetTypedMessageHash function -// equivalent of get struct hash. -// -// It tests the generation of a typed message hash for a given mail object using a specific curve. -// The function expects the mail object to have a "From" field of type Person, a "To" field of type Person, -// and a "Contents" field of type string. It returns the generated hash as a byte array and an error object. -// -// Parameters: -// - t: a testing.T object that provides methods for testing functions -// Returns: -// -// none -func TestGeneral_GetTypedMessageHash(t *testing.T) { - ttd, err := MockTypedData() - require.NoError(t, err) - - mail := Mail{ - From: Person{ - Name: "Cow", - Wallet: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", - }, - To: Person{ - Name: "Bob", - Wallet: "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", - }, - Contents: "Hello, Bob!", - } - - hash := ttd.GetTypedMessageHash("Mail", mail) - - exp := "0x4758f1ed5e7503120c228cbcaba626f61514559e9ef5ed653b0b885e0f38aec" - require.Equal(t, exp, utils.BigToHex(hash)) -} - -// TestGeneral_GetTypeHash tests the GetTypeHash function. -// -// It tests the GetTypeHash function by calling it with different input values -// and comparing the result with expected values. It also checks that the -// encoding of the types matches the expected values. -// -// Parameters: -// - t: The testing.T object used for reporting test failures and logging test output -// Returns: -// -// none -func TestGeneral_GetTypeHash(t *testing.T) { - require := require.New(t) - - ttd, err := MockTypedData() - require.NoError(err) - - hash, err := ttd.GetTypeHash("StarkNetDomain") - require.NoError(err) - - exp := "0x1bfc207425a47a5dfa1a50a4f5241203f50624ca5fdf5e18755765416b8e288" - require.Equal(exp, utils.BigToHex(hash)) - - enc := ttd.Types["StarkNetDomain"] - require.Equal(exp, utils.BigToHex(enc.Encoding)) - - pHash, err := ttd.GetTypeHash("Person") - require.NoError(err) - - exp = "0x2896dbe4b96a67110f454c01e5336edc5bbc3635537efd690f122f4809cc855" - require.Equal(exp, utils.BigToHex(pHash)) - - enc = ttd.Types["Person"] - require.Equal(exp, utils.BigToHex(enc.Encoding)) -} - -// TestGeneral_GetSelectorFromName tests the GetSelectorFromName function. -// -// It checks if the GetSelectorFromName function returns the expected values -// for different input names. -// The expected values are hard-coded and compared against the actual values. -// If any of the actual values do not match the expected values, an error is -// reported. -// -// Parameters: -// - t: The testing.T object used for reporting test failures and logging test output -// Returns: -// -// none -func TestGeneral_GetSelectorFromName(t *testing.T) { - sel1 := utils.BigToHex(utils.GetSelectorFromName("initialize")) - sel2 := utils.BigToHex(utils.GetSelectorFromName("mint")) - sel3 := utils.BigToHex(utils.GetSelectorFromName("test")) - - exp1 := "0x79dc0da7c54b95f10aa182ad0a46400db63156920adb65eca2654c0945a463" - exp2 := "0x2f0b3c5710379609eb5495f1ecd348cb28167711b73609fe565a72734550354" - exp3 := "0x22ff5f21f0b81b113e63f7db6da94fedef11b2119b4088b89664fb9a3cb658" - - if sel1 != exp1 || sel2 != exp2 || sel3 != exp3 { - t.Errorf("invalid Keccak256 encoding: %v %v %v\n", sel1, sel2, sel3) - } -} - -// TestGeneral_EncodeType tests the EncodeType function. -// -// It creates a mock typed data and calls the EncodeType method with the -// parameter "Mail". It checks if the returned encoding matches the expected -// encoding. If there is an error during the encoding process, it fails the -// test. -// -// Parameters: -// - t: The testing.T object used for reporting test failures and logging test output -// Returns: -// -// none -func TestGeneral_EncodeType(t *testing.T) { - ttd, err := MockTypedData() - require.NoError(t, err) - - enc, err := ttd.EncodeType("Mail") - require.NoError(t, err) - - exp := "Mail(from:Person,to:Person,contents:felt)Person(name:felt,wallet:felt)" - require.Equal(t, exp, enc) -} diff --git a/typedData/revision.go b/typedData/revision.go new file mode 100644 index 00000000..4e2780f1 --- /dev/null +++ b/typedData/revision.go @@ -0,0 +1,226 @@ +package typedData + +import ( + "fmt" + "slices" + "strings" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/curve" +) + +var ( + // There is also an array version of each type. The array is defined like this: 'type' + '*' (e.g.: "felt*", "bool*", "string*"...) + revision_0_basic_types []string = []string{ + "felt", + "bool", + "string", //up to 31 ASCII characters + "selector", + "merkletree", + } + + // Revision 1 includes all types from Revision 0 plus these. The only difference is that for Revision 1 "string" represents an + // arbitrary size string instead of having a 31 ASCII characters limit in Revision 0; for this limit, use the new type "shortstring" instead. + // + // There is also an array version of each type. The array is defined like this: 'type' + '*' (e.g.: "ClassHash*", "timestamp*", "shortstring*"...) + revision_1_basic_types []string = []string{ + "enum", + "u128", + "i128", + "ContractAddress", + "ClassHash", + "timestamp", + "shortstring", + } + + //lint:ignore U1000 Variable used to check Preset types in other pieces of code + revision_1_preset_types []string = []string{ + "NftId", + "TokenAmount", + "u256", + } +) + +var RevisionV0 revision +var RevisionV1 revision + +func init() { + presetMap := make(map[string]TypeDefinition) + + RevisionV0 = revision{ + version: 0, + domain: "StarkNetDomain", + hashMethod: curve.PedersenArray, + hashMerkleMethod: curve.Pedersen, + types: RevisionTypes{ + Basic: revision_0_basic_types, + Preset: presetMap, + }, + } + + presetMap = getRevisionV1PresetTypes() + + RevisionV1 = revision{ + version: 1, + domain: "StarknetDomain", + hashMethod: curve.PoseidonArray, + hashMerkleMethod: curve.Poseidon, + types: RevisionTypes{ + Basic: append(revision_1_basic_types, revision_0_basic_types...), + Preset: presetMap, + }, + } +} + +type revision struct { + //TODO: create a enum + version uint8 + domain string + hashMethod func(felts ...*felt.Felt) *felt.Felt + hashMerkleMethod func(a, b *felt.Felt) *felt.Felt + types RevisionTypes +} + +type RevisionTypes struct { + Basic []string + Preset map[string]TypeDefinition +} + +func (rev *revision) Version() uint8 { + return rev.version +} + +func (rev *revision) Domain() string { + return rev.domain +} + +func (rev *revision) HashMethod(felts ...*felt.Felt) *felt.Felt { + return rev.hashMethod(felts...) +} + +func (rev *revision) HashMerkleMethod(a *felt.Felt, b *felt.Felt) *felt.Felt { + var first, second *felt.Felt + if a.Cmp(b) > 0 { + first = b + second = a + } else { + first = a + second = b + } + return rev.hashMerkleMethod(first, second) +} + +func (rev *revision) Types() RevisionTypes { + return rev.types +} + +func GetRevision(version uint8) (rev *revision, err error) { + switch version { + case 0: + return &RevisionV0, nil + case 1: + return &RevisionV1, nil + default: + return rev, fmt.Errorf("invalid revision version") + } +} + +func getRevisionV1PresetTypes() map[string]TypeDefinition { + NftIdEnc, _ := new(felt.Felt).SetString("0xaf7d0f5e34446178d80fadf5ddaaed52347121d2fac19ff184ff508d4776f2") + TokenAmountEnc, _ := new(felt.Felt).SetString("0x14648649d4413eb385eea9ac7e6f2b9769671f5d9d7ad40f7b4aadd67839d4") + u256dEnc, _ := new(felt.Felt).SetString("0x3b143be38b811560b45593fb2a071ec4ddd0a020e10782be62ffe6f39e0e82c") + + presetTypes := []TypeDefinition{ + { + Name: "NftId", + Enconding: NftIdEnc, + EncoddingString: `"NftId"("collection_address":"ContractAddress","token_id":"u256")"u256"("low":"u128","high":"u128")`, + SingleEncString: `"NftId"("collection_address":"ContractAddress","token_id":"u256")`, + ReferencedTypesEnc: []string{`"u256"("low":"u128","high":"u128")`}, + Parameters: []TypeParameter{ + { + Name: "collection_address", + Type: "ContractAddress", + }, + { + Name: "token_id", + Type: "u256", + }, + }, + }, + { + Name: "TokenAmount", + Enconding: TokenAmountEnc, + EncoddingString: `"TokenAmount"("token_address":"ContractAddress","amount":"u256")"u256"("low":"u128","high":"u128")`, + SingleEncString: `"TokenAmount"("token_address":"ContractAddress","amount":"u256")`, + ReferencedTypesEnc: []string{`"u256"("low":"u128","high":"u128")`}, + Parameters: []TypeParameter{ + { + Name: "token_address", + Type: "ContractAddress", + }, + { + Name: "amount", + Type: "u256", + }, + }, + }, + { + Name: "u256", + Enconding: u256dEnc, + EncoddingString: `"u256"("low":"u128","high":"u128")`, + SingleEncString: `"u256"("low":"u128","high":"u128")`, + ReferencedTypesEnc: []string{}, + Parameters: []TypeParameter{ + { + Name: "low", + Type: "u128", + }, + { + Name: "high", + Type: "u128", + }, + }, + }, + } + + presetTypesMap := make(map[string]TypeDefinition) + + for _, typeDef := range presetTypes { + presetTypesMap[typeDef.Name] = typeDef + } + + return presetTypesMap +} + +// Check if the provided type name is a standard type defined at the SNIP 12, also validates arrays +func isStandardType(typeName string) bool { + typeName, _ = strings.CutSuffix(typeName, "*") + + if slices.Contains(revision_0_basic_types, typeName) || + slices.Contains(revision_1_basic_types, typeName) || + slices.Contains(revision_1_preset_types, typeName) { + return true + } + + return false +} + +// Check if the provided type name is a basic type defined at the SNIP 12, also validates arrays +func isBasicType(typeName string) bool { + typeName, _ = strings.CutSuffix(typeName, "*") + + if slices.Contains(revision_0_basic_types, typeName) || + slices.Contains(revision_1_basic_types, typeName) { + return true + } + + return false +} + +// Check if the provided type name is a preset type defined at the SNIP 12, also validates arrays +func isPresetType(typeName string) bool { + typeName, _ = strings.CutSuffix(typeName, "*") + + return slices.Contains(revision_1_preset_types, typeName) +} diff --git a/typedData/tests/allInOne.json b/typedData/tests/allInOne.json new file mode 100644 index 00000000..71b8622d --- /dev/null +++ b/typedData/tests/allInOne.json @@ -0,0 +1,163 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Setup": [ + { "name": "multiEnumExample", "type": "Example" }, + { "name": "basicTypesExample", "type": "BasicTypes" }, + { "name": "nestedExample", "type": "Nested1" }, + { "name": "merkleTreeExample", "type": "merkletree", "contains": "MerkleTreeLeaf" } + ], + "Example": [ + { "name": "someEnum1", "type": "enum", "contains": "EnumA" }, + { "name": "someEnum2", "type": "enum", "contains": "EnumB" } + ], + "EnumA": [ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(u128,u128*)" }, + { "name": "Variant 3", "type": "(u128)" } + ], + "EnumB": [ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(u128)" } + ], + "BasicTypes": [ + { "name": "n0", "type": "felt" }, + { "name": "n1", "type": "bool" }, + { "name": "n2", "type": "string" }, + { "name": "n3", "type": "selector" }, + { "name": "n4", "type": "u128" }, + { "name": "n5", "type": "i128" }, + { "name": "n6", "type": "ContractAddress" }, + { "name": "n7", "type": "ClassHash" }, + { "name": "n8", "type": "timestamp" }, + { "name": "n9", "type": "shortstring" } + ], + "Nested1": [ + { "name": "n1", "type": "bool*" }, + { "name": "n2", "type": "Nested2" } + ], + "Nested2": [ + { "name": "n1", "type": "i128*" }, + { "name": "n2", "type": "Nested3" } + ], + "Nested3": [ + { "name": "n1", "type": "shortstring*" }, + { "name": "n2", "type": "Nested4" } + ], + "Nested4": [ + { "name": "n1", "type": "TokenAmount*" }, + { "name": "n2", "type": "Nested5" } + ], + "Nested5": [ + { "name": "n1", "type": "NftId*" }, + { "name": "n2", "type": "u256*" } + ], + "MerkleTreeLeaf": [ + { "name": "timestamp", "type": "timestamp" }, + { "name": "block_hash", "type": "felt" } + ] + }, + "primaryType": "Setup", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "multiEnumExample": { + "someEnum1": { + "Variant 2": [2, [0, 1, 34, 8748]] + }, + "someEnum2": { + "Variant 1": [] + } + }, + "basicTypesExample": { + "n0": "0x1a2b3c4d5e6f", + "n1": true, + "n2": "Lorem ipsum alskdj alskdjaslkd sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et alskdj alskdjaslkde magna aliqua.", + "n3": "transfers", + "n4": 101927, + "n5": -12980, + "n6": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004d", + "n7": "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcd", + "n8": 100898790, + "n9": "transfer tokens" + }, + "nestedExample": { + "n1": [true, false], + "n2": { + "n1": [-12980, 12980], + "n2": { + "n1": ["transfer tokens", "transfer nfts"], + "n2": { + "n1": [ + { + "token_address": "0x019d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", + "amount": { + "low": "0x1", + "high": "0x0" + } + }, + { + "token_address": "0x029d36570d4e46f48e99674bd3fcc84364ab56b96f7c741b1562b82f9e004dc1", + "amount": { + "low": "0x1234", + "high": "0x0" + } + } + ], + "n2": { + "n1": [ + { + "collection_address": "0x022b14c83d9f25e16a4c73b98f5612d3e7c4590f2a8b369c4d15e70a3b291f41", + "token_id": { + "low": "0x3e8", + "high": "0x0" + } + }, + { + "collection_address": "0x0234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + "token_id": { + "low": "0x3e8", + "high": "0x0" + } + } + ], + "n2": [ + { + "low": "0x3e88956", + "high": "0x0" + }, + { + "low": "0x3e39228", + "high": "0x0" + } + ] + } + } + } + } + }, + "merkleTreeExample": [ + { + "timestamp": 100898790, + "block_hash": "0x1a2b3c446e6f" + }, + { + "timestamp": 100898791, + "block_hash": "0x783c4d5e6f" + }, + { + "timestamp": 100898792, + "block_hash": "0x647b3c4d5e6f" + } + ] + } +} diff --git a/typedData/tests/baseExample.json b/typedData/tests/baseExample.json new file mode 100644 index 00000000..0780ecd9 --- /dev/null +++ b/typedData/tests/baseExample.json @@ -0,0 +1,35 @@ +{ + "types": { + "StarkNetDomain": [ + { "name": "name", "type": "felt" }, + { "name": "version", "type": "felt" }, + { "name": "chainId", "type": "felt" } + ], + "Person": [ + { "name": "name", "type": "felt" }, + { "name": "wallet", "type": "felt" } + ], + "Mail": [ + { "name": "from", "type": "Person" }, + { "name": "to", "type": "Person" }, + { "name": "contents", "type": "felt" } + ] + }, + "primaryType": "Mail", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": 1 + }, + "message": { + "from": { + "name": "Cow", + "wallet": "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826" + }, + "to": { + "name": "Bob", + "wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB" + }, + "contents": "Hello, Bob!" + } +} \ No newline at end of file diff --git a/typedData/tests/example_array.json b/typedData/tests/example_array.json new file mode 100644 index 00000000..ba6a7aa9 --- /dev/null +++ b/typedData/tests/example_array.json @@ -0,0 +1,34 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example Message": [ + { "name": "Name", "type": "string" }, + { "name": "Some Array", "type": "u128*" }, + { "name": "Some Object", "type": "My Object" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "StarknetDomain", + "version": "1", + "chainId": "SN_MAIN", + "revision" : 1 + }, + "message": { + "Name": "some name", + "Some Array": [1, 2, 3, 4], + "Some Object": { + "Some Selector": "transfer", + "Some Contract Address": "0x0123" + } + } +} \ No newline at end of file diff --git a/typedData/tests/example_baseTypes.json b/typedData/tests/example_baseTypes.json new file mode 100644 index 00000000..db504cad --- /dev/null +++ b/typedData/tests/example_baseTypes.json @@ -0,0 +1,41 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [ + { "name": "n0", "type": "felt" }, + { "name": "n1", "type": "bool" }, + { "name": "n2", "type": "string" }, + { "name": "n3", "type": "selector" }, + { "name": "n4", "type": "u128" }, + { "name": "n5", "type": "i128" }, + { "name": "n6", "type": "ContractAddress" }, + { "name": "n7", "type": "ClassHash" }, + { "name": "n8", "type": "timestamp" }, + { "name": "n9", "type": "shortstring" } + ] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "n0": "0x3e8", + "n1": true, + "n2": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", + "n3": "transfer", + "n4": 10, + "n5": -10, + "n6": "0x3e8", + "n7": "0x3e8", + "n8": 1000, + "n9": "transfer" + } +} \ No newline at end of file diff --git a/typedData/tests/example_enum.json b/typedData/tests/example_enum.json new file mode 100644 index 00000000..d9f9a3a7 --- /dev/null +++ b/typedData/tests/example_enum.json @@ -0,0 +1,38 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [ + { "name": "someEnum1", "type": "enum", "contains": "EnumA" }, + { "name": "someEnum2", "type": "enum", "contains": "EnumB" } + ], + "EnumA": [ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(u128,u128*)" }, + { "name": "Variant 3", "type": "(u128)" } + ], + "EnumB": [ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(u128)" } + ] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "someEnum1": { + "Variant 2": [2, [0, 1]] + }, + "someEnum2": { + "Variant 1": [] + } + } +} \ No newline at end of file diff --git a/typedData/tests/example_presetTypes.json b/typedData/tests/example_presetTypes.json new file mode 100644 index 00000000..ed810db5 --- /dev/null +++ b/typedData/tests/example_presetTypes.json @@ -0,0 +1,37 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [ + { "name": "n0", "type": "TokenAmount" }, + { "name": "n1", "type": "NftId" } + ] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "n0": { + "token_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", + "amount": { + "low": "0x3e8", + "high": "0x0" + } + }, + "n1": { + "collection_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", + "token_id": { + "low": "0x3e8", + "high": "0x0" + } + } + } +} \ No newline at end of file diff --git a/typedData/tests/mail_StructArray.json b/typedData/tests/mail_StructArray.json new file mode 100644 index 00000000..b3035f98 --- /dev/null +++ b/typedData/tests/mail_StructArray.json @@ -0,0 +1,44 @@ +{ + "types": { + "StarkNetDomain": [ + { "name": "name", "type": "felt" }, + { "name": "version", "type": "felt" }, + { "name": "chainId", "type": "felt" } + ], + "Person": [ + { "name": "name", "type": "felt" }, + { "name": "wallet", "type": "felt" } + ], + "Post": [ + { "name": "title", "type": "felt" }, + { "name": "content", "type": "felt" } + ], + "Mail": [ + { "name": "from", "type": "Person" }, + { "name": "to", "type": "Person" }, + { "name": "posts_len", "type": "felt" }, + { "name": "posts", "type": "Post*" } + ] + }, + "primaryType": "Mail", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": 1 + }, + "message": { + "from": { + "name": "Cow", + "wallet": "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826" + }, + "to": { + "name": "Bob", + "wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB" + }, + "posts_len": 2, + "posts": [ + { "title": "Greeting", "content": "Hello, Bob!" }, + { "title": "Farewell", "content": "Goodbye, Bob!" } + ] + } +} \ No newline at end of file diff --git a/typedData/tests/session_MerkleTree.json b/typedData/tests/session_MerkleTree.json new file mode 100644 index 00000000..b580db74 --- /dev/null +++ b/typedData/tests/session_MerkleTree.json @@ -0,0 +1,42 @@ +{ + "primaryType": "Session", + "types": { + "Policy": [ + { "name": "contractAddress", "type": "felt" }, + { "name": "selector", "type": "selector" } + ], + "Session": [ + { "name": "key", "type": "felt" }, + { "name": "expires", "type": "felt" }, + { "name": "root", "type": "merkletree", "contains": "Policy" } + ], + "StarkNetDomain": [ + { "name": "name", "type": "felt" }, + { "name": "version", "type": "felt" }, + { "name": "chain_id", "type": "felt" } + ] + }, + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chain_id": 1 + }, + "message": { + "key": "0x0000000000000000000000000000000000000000000000000000000000000000", + "expires": "0x0000000000000000000000000000000000000000000000000000000000000000", + "root": [ + { + "contractAddress": "0x1", + "selector": "transfer" + }, + { + "contractAddress": "0x2", + "selector": "transfer" + }, + { + "contractAddress": "0x3", + "selector": "transfer" + } + ] + } +} \ No newline at end of file diff --git a/typedData/tests/v1Nested.json b/typedData/tests/v1Nested.json new file mode 100644 index 00000000..5a5e30af --- /dev/null +++ b/typedData/tests/v1Nested.json @@ -0,0 +1,66 @@ +{ + "domain": { + "name": "Dappland", + "chainId": "0x534e5f5345504f4c4941", + "version": "1.0.2", + "revision": "1" + }, + "message": { + "MessageId": 345, + "From": { + "Name": "Edmund", + "Address": "0x7e00d496e324876bbc8531f2d9a82bf154d1a04a50218ee74cdd372f75a551a" + }, + "To": { + "Name": "Alice", + "Address": "0x69b49c2cc8b16e80e86bfc5b0614a59aa8c9b601569c7b80dde04d3f3151b79" + }, + "Nft_to_transfer": { + "Collection": "Stupid monkeys", + "Address": "0x69b49c2cc8b16e80e86bfc5b0614a59aa8c9b601569c7b80dde04d3f3151b79", + "Nft_id": 112, + "Negotiated_for": { + "Qty": "18.4569325643", + "Unit": "ETH", + "Token_address": "0x69b49c2cc8b16e80e86bfc5b0614a59aa8c9b601569c7b80dde04d3f3151b79", + "Amount": "0x100243260D270EB00" + } + }, + "Comment1": "Monkey with banana, sunglasses,", + "Comment2": "and red hat.", + "Comment3": "" + }, + "primaryType": "TransferERC721", + "types": { + "Account1": [ + {"name": "Name", "type": "string"}, + {"name": "Address", "type": "felt"} + ], + "Nft": [ + {"name": "Collection", "type": "string"}, + {"name": "Address", "type": "felt"}, + {"name": "Nft_id", "type": "felt"}, + {"name": "Negotiated_for", "type": "Transaction"} + ], + "Transaction": [ + {"name": "Qty", "type": "string"}, + {"name": "Unit", "type": "string"}, + {"name": "Token_address", "type": "felt"}, + {"name": "Amount", "type": "felt"} + ], + "TransferERC721": [ + {"name": "MessageId", "type": "felt"}, + {"name": "From", "type": "Account1"}, + {"name": "To", "type": "Account1"}, + {"name": "Nft_to_transfer", "type": "Nft"}, + {"name": "Comment1", "type": "string"}, + {"name": "Comment2", "type": "string"}, + {"name": "Comment3", "type": "string"} + ], + "StarknetDomain": [ + {"name": "name", "type": "string"}, + {"name": "chainId", "type": "felt"}, + {"name": "version", "type": "string"} + ] + } +} \ No newline at end of file diff --git a/typedData/typedData.go b/typedData/typedData.go new file mode 100644 index 00000000..e56cf372 --- /dev/null +++ b/typedData/typedData.go @@ -0,0 +1,866 @@ +package typedData + +import ( + "bytes" + "encoding/json" + "fmt" + "math/big" + "regexp" + "slices" + "strconv" + "strings" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/utils" +) + +type TypedData struct { + Types map[string]TypeDefinition + PrimaryType string + Domain Domain + Message map[string]any + Revision *revision +} + +type Domain struct { + Name string `json:"name"` + Version string `json:"version"` + ChainId string `json:"chainId"` + Revision uint8 `json:"revision,omitempty"` +} + +type TypeDefinition struct { + Name string `json:"-"` + Enconding *felt.Felt `json:"-"` + EncoddingString string `json:"-"` + SingleEncString string `json:"-"` + ReferencedTypesEnc []string `json:"-"` + Parameters []TypeParameter +} + +type TypeParameter struct { + Name string `json:"name"` + Type string `json:"type"` + Contains string `json:"contains,omitempty"` +} + +// NewTypedData creates a new instance of TypedData. +// +// Parameters: +// - types: a slice of TypeDefinition representing the types used in the TypedData. +// - primaryType: a string representing the primary type of the TypedData. +// - domain: a Domain struct representing the domain information of the TypedData. +// - message: a byte slice containing the JSON-encoded message. +// +// Returns: +// - td: a pointer to the newly created TypedData instance. +// - err: an error if any occurred during the creation of the TypedData. +func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, message []byte) (td *TypedData, err error) { + //types + typesMap := make(map[string]TypeDefinition) + for _, typeDef := range types { + typesMap[typeDef.Name] = typeDef + } + + //primary type + if _, ok := typesMap[primaryType]; !ok { + return td, fmt.Errorf("invalid primary type: %s", primaryType) + } + + //message + messageMap := make(map[string]any) + err = json.Unmarshal(message, &messageMap) + if err != nil { + return td, fmt.Errorf("error unmarshalling the message: %w", err) + } + + //revision + revision, err := GetRevision(domain.Revision) + if err != nil { + return td, fmt.Errorf("error getting revision: %w", err) + } + + //domain type encoding + domainTypeDef, err := encodeTypes(revision.Domain(), typesMap, revision) + if err != nil { + return td, err + } + typesMap[revision.Domain()] = domainTypeDef + + //types encoding + primaryTypeDef, err := encodeTypes(primaryType, typesMap, revision) + if err != nil { + return td, err + } + typesMap[primaryType] = primaryTypeDef + + for _, typeDef := range typesMap { + if typeDef.EncoddingString == "" { + return td, fmt.Errorf("'encodeTypes' failed: type '%s' doesn't have encode value", typeDef.Name) + } + } + + td = &TypedData{ + Types: typesMap, + PrimaryType: primaryType, + Domain: domain, + Message: messageMap, + Revision: revision, + } + + return td, nil +} + +// GetMessageHash calculates the hash of a typed message for a given account using the StarkCurve. +// +// (ref: https://github.com/starknet-io/SNIPs/blob/5d5a42c654c27b377d8b7f90b453065fd19ec2eb/SNIPS/snip-12.md#specification) +// +// Parameters: +// - account: A string representing the account. +// Returns: +// - hash: A pointer to a felt.Felt representing the calculated hash. +func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) { + //signed_data = encode(PREFIX_MESSAGE, Enc[domain_separator], account, Enc[message]) + + //PREFIX_MESSAGE + prefixMessage, err := utils.HexToFelt(utils.StrToHex("StarkNet Message")) + if err != nil { + return hash, err + } + + //Enc[domain_separator] + domEnc, err := td.GetStructHash(td.Revision.Domain()) + if err != nil { + return hash, err + } + + //account + accountFelt, err := utils.HexToFelt(account) + if err != nil { + return hash, err + } + + //Enc[message] + msgEnc, err := td.GetStructHash(td.PrimaryType) + if err != nil { + return hash, err + } + + return td.Revision.HashMethod(prefixMessage, domEnc, accountFelt, msgEnc), nil +} + +// GetStructHash calculates the hash of a struct type and its respective data. +// +// Parameters: +// - typeName: the name of the type to be hashed. +// - context: optional context strings to be included in the hash calculation. +// +// You can use 'context' to specify the path of the type you want to hash. Example: if you want to hash the type "ExampleInner" +// that is within the "Example" primary type with the name of "example_inner", you can specify the context as ["example_inner"]. +// If "ExampleInner" has a parameter with the name of "example_inner_inner" that you want to know the hash, you can specify the context +// as ["example_inner", "example_inner_inner"]. +// +// Returns: +// - hash: A pointer to a felt.Felt representing the calculated hash. +// - err: an error if any occurred during the hash calculation. +func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *felt.Felt, err error) { + typeDef, ok := td.Types[typeName] + if !ok { + if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { + return hash, fmt.Errorf("error getting the type definition of %s", typeName) + } + } + encTypeData, err := EncodeData(&typeDef, td, context...) + if err != nil { + return hash, err + } + + return td.Revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil +} + +// shortGetStructHash is a helper function that calculates the hash of a struct type and its respective data. +func shortGetStructHash( + typeDef *TypeDefinition, + typedData *TypedData, + data map[string]any, + isEnum bool, + context ...string, +) (hash *felt.Felt, err error) { + + encTypeData, err := encodeData(typeDef, typedData, data, isEnum, context...) + if err != nil { + return hash, err + } + + return typedData.Revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil +} + +// GetTypeHash returns the hash of the given type. +// +// Parameters: +// - typeName: the name of the type to hash +// Returns: +// - hash: A pointer to a felt.Felt representing the calculated hash. +// - err: an error if any occurred during the hash calculation. +func (td *TypedData) GetTypeHash(typeName string) (*felt.Felt, error) { + //TODO: create/update methods descriptions + typeDef, ok := td.Types[typeName] + if !ok { + if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { + return typeDef.Enconding, fmt.Errorf("type '%s' not found", typeName) + } + } + return typeDef.Enconding, nil +} + +// encodeTypes encodes the given type name using the TypedData struct. +// Parameters: +// - typeName: name of the type to encode +// - types: map of type definitions +// - revision: revision information +// - isEnum: optional boolean indicating if type is an enum +// Returns: +// - newTypeDef: the encoded type definition +// - err: any error encountered during encoding +func encodeTypes(typeName string, types map[string]TypeDefinition, revision *revision, isEnum ...bool) (newTypeDef TypeDefinition, err error) { + getTypeEncodeString := func(typeName string, typeDef TypeDefinition, customTypesStringEnc *[]string, isEnum ...bool) (result string, err error) { + verifyTypeName := func(param TypeParameter, isEnum ...bool) error { + singleTypeName, _ := strings.CutSuffix(param.Type, "*") + + if isBasicType(singleTypeName) { + if singleTypeName == "merkletree" { + if param.Contains == "" { + return fmt.Errorf("missing 'contains' value from '%s'", param.Name) + } + newTypeDef, err := encodeTypes(param.Contains, types, revision) + if err != nil { + return err + } + + types[param.Contains] = newTypeDef + } + return nil + } + + if isPresetType(singleTypeName) { + typeEnc, ok := revision.Types().Preset[singleTypeName] + if !ok { + return fmt.Errorf("error trying to get the type definition of '%s'", singleTypeName) + } + *customTypesStringEnc = append(*customTypesStringEnc, append([]string{typeEnc.SingleEncString}, typeEnc.ReferencedTypesEnc...)...) + + return nil + } + + if newTypeDef := types[singleTypeName]; newTypeDef.SingleEncString != "" { + *customTypesStringEnc = append(*customTypesStringEnc, append([]string{newTypeDef.SingleEncString}, newTypeDef.ReferencedTypesEnc...)...) + return nil + } + + newTypeDef, err := encodeTypes(singleTypeName, types, revision, isEnum...) + if err != nil { + return err + } + + *customTypesStringEnc = append(*customTypesStringEnc, append([]string{newTypeDef.SingleEncString}, newTypeDef.ReferencedTypesEnc...)...) + types[singleTypeName] = newTypeDef + + return nil + } + + var buf bytes.Buffer + quotationMark := "" + if revision.Version() == 1 { + quotationMark = `"` + } + + buf.WriteString(quotationMark + typeName + quotationMark) + buf.WriteString("(") + + for i, param := range typeDef.Parameters { + if len(isEnum) != 0 { + reg, err := regexp.Compile(`[^\(\),\s]+`) + if err != nil { + return "", err + } + typesArr := reg.FindAllString(param.Type, -1) + var fullTypeName string + for i, typeNam := range typesArr { + fullTypeName += `"` + typeNam + `"` + if i < (len(typesArr) - 1) { + fullTypeName += `,` + } + } + buf.WriteString(fmt.Sprintf(quotationMark+"%s"+quotationMark+":"+`(`+"%s"+`)`, param.Name, fullTypeName)) + + for _, typeNam := range typesArr { + err = verifyTypeName(TypeParameter{Type: typeNam}) + if err != nil { + return "", err + } + } + } else { + currentTypeName := param.Type + + if currentTypeName == "enum" { + if param.Contains == "" { + return "", fmt.Errorf("missing 'contains' value from '%s'", param.Name) + } + currentTypeName = param.Contains + err = verifyTypeName(TypeParameter{Type: currentTypeName}, true) + if err != nil { + return "", err + } + } + + buf.WriteString(fmt.Sprintf(quotationMark+"%s"+quotationMark+":"+quotationMark+"%s"+quotationMark, param.Name, currentTypeName)) + + err = verifyTypeName(param) + if err != nil { + return "", err + } + } + if i != (len(typeDef.Parameters) - 1) { + buf.WriteString(",") + } + } + buf.WriteString(")") + + return buf.String(), nil + } + + typeDef, ok := types[typeName] + if !ok { + return typeDef, fmt.Errorf("can't parse type %s from types %v", typeName, types) + } + + if newTypeDef = types[typeName]; newTypeDef.EncoddingString != "" { + return newTypeDef, nil + } + + referencedTypesEnc := make([]string, 0) + + singleEncString, err := getTypeEncodeString(typeName, typeDef, &referencedTypesEnc, isEnum...) + if err != nil { + return typeDef, err + } + + fullEncString := singleEncString + // appends the custom types' encode + if len(referencedTypesEnc) > 0 { + // temp map just to remove duplicated items + uniqueMap := make(map[string]bool) + for _, typeEncStr := range referencedTypesEnc { + uniqueMap[typeEncStr] = true + } + // clear the array + referencedTypesEnc = make([]string, 0, len(uniqueMap)) + // fill it again + for typeEncStr := range uniqueMap { + referencedTypesEnc = append(referencedTypesEnc, typeEncStr) + } + + slices.Sort(referencedTypesEnc) + + for _, typeEncStr := range referencedTypesEnc { + fullEncString += typeEncStr + } + } + + newTypeDef = TypeDefinition{ + Name: typeDef.Name, + Parameters: typeDef.Parameters, + Enconding: utils.GetSelectorFromNameFelt(fullEncString), + EncoddingString: fullEncString, + SingleEncString: singleEncString, + ReferencedTypesEnc: referencedTypesEnc, + } + + return newTypeDef, nil +} + +// EncodeData encodes the given type definition using the TypedData struct. +func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc []*felt.Felt, err error) { + if typeDef.Name == "StarkNetDomain" || typeDef.Name == "StarknetDomain" { + domainMap := make(map[string]any) + domainBytes, err := json.Marshal(td.Domain) + if err != nil { + return enc, err + } + err = json.Unmarshal(domainBytes, &domainMap) + if err != nil { + return enc, err + } + + // ref: https://community.starknet.io/t/signing-transactions-and-off-chain-messages/66 + domainMap["chain_id"] = domainMap["chainId"] + + return encodeData(typeDef, td, domainMap, false, context...) + } + + return encodeData(typeDef, td, td.Message, false, context...) +} + +// encodeData is a helper function that encodes the given type definition using the TypedData struct. +// +// Parameters: +// - typeDef: a pointer to the TypeDefinition representing the type to be encoded. +// - typedData: a pointer to the TypedData struct containing the data to be encoded. +// - data: a map containing the data to be encoded. +// - isEnum: a boolean indicating whether the type is an enum. +// - context: optional context strings to be included in the encoding process. +// +// The function first checks if the context is provided and updates the data map accordingly. +// It then defines helper functions to handle standard types, object types, and arrays. +// The main encoding logic is implemented within these helper functions. +// +// Returns: +// - enc: a slice of pointers to felt.Felt representing the encoded data. +// - err: an error if any occurred during the encoding process. +func encodeData( + typeDef *TypeDefinition, + typedData *TypedData, + data map[string]any, + isEnum bool, + context ...string, +) (enc []*felt.Felt, err error) { + if len(context) != 0 { + for _, paramName := range context { + value, ok := data[paramName] + if !ok { + return enc, fmt.Errorf("context error: parameter '%s' not found in the data map", paramName) + } + newData, ok := value.(map[string]any) + if !ok { + return enc, fmt.Errorf("context error: error generating the new data map") + } + data = newData + } + } + + // helper functions + verifyType := func(param TypeParameter, data any, isEnum bool) (resp *felt.Felt, err error) { + //helper functions + var handleStandardTypes func(param TypeParameter, data any, rev *revision) (resp *felt.Felt, err error) + var handleObjectTypes func(typeDef *TypeDefinition, data any, isEnum ...bool) (resp *felt.Felt, err error) + var handleArrays func(param TypeParameter, data any, rev *revision, isMerkle ...bool) (resp *felt.Felt, err error) + + handleStandardTypes = func(param TypeParameter, data any, rev *revision) (resp *felt.Felt, err error) { + switch param.Type { + case "merkletree": + tempParam := TypeParameter{ + Name: param.Name, + Type: param.Contains, + } + resp, err := handleArrays(tempParam, data, rev, true) + if err != nil { + return resp, err + } + return resp, nil + case "enum": + typeDef, ok := typedData.Types[param.Contains] + if !ok { + return resp, fmt.Errorf("error trying to get the type definition of '%s' in contains of '%s'", param.Contains, param.Name) + } + resp, err := handleObjectTypes(&typeDef, data, true) + if err != nil { + return resp, err + } + return resp, nil + case "NftId", "TokenAmount", "u256": + typeDef, ok := rev.Types().Preset[param.Type] + if !ok { + return resp, fmt.Errorf("error trying to get the type definition of '%s'", param.Type) + } + resp, err := handleObjectTypes(&typeDef, data) + if err != nil { + return resp, err + } + return resp, nil + default: + resp, err := encodePieceOfData(param.Type, data, rev) + if err != nil { + return resp, err + } + return resp, nil + } + } + + handleObjectTypes = func(typeDef *TypeDefinition, data any, isEnum ...bool) (resp *felt.Felt, err error) { + mapData, ok := data.(map[string]any) + if !ok { + return resp, fmt.Errorf("error trying to convert the value of '%s' to an map", typeDef) + } + + if len(isEnum) != 0 && isEnum[0] { + resp, err = shortGetStructHash(typeDef, typedData, mapData, true) + } else { + resp, err = shortGetStructHash(typeDef, typedData, mapData, false) + } + if err != nil { + return resp, err + } + + return resp, nil + } + + handleArrays = func(param TypeParameter, data any, rev *revision, isMerkle ...bool) (resp *felt.Felt, err error) { + var handleMerkleTree func(felts []*felt.Felt) *felt.Felt + // ref https://github.com/starknet-io/starknet.js/blob/3cfdd8448538128bf9fd158d2e87be20310a69e3/src/utils/merkle.ts#L41 + handleMerkleTree = func(felts []*felt.Felt) *felt.Felt { + if len(felts) == 1 { + return felts[0] + } + var localArr []*felt.Felt + + for i := 0; i < len(felts); i += 2 { + if i+1 == len(felts) { + localArr = append(localArr, rev.HashMerkleMethod(felts[i], new(felt.Felt))) + } else { + localArr = append(localArr, rev.HashMerkleMethod(felts[i], felts[i+1])) + } + } + + return handleMerkleTree(localArr) + } + + dataArray, ok := data.([]any) + if !ok { + return resp, fmt.Errorf("error trying to convert the value of '%s' to an array", param.Name) + } + localEncode := []*felt.Felt{} + singleParamType, _ := strings.CutSuffix(param.Type, "*") + + if isBasicType(singleParamType) { + for _, item := range dataArray { + resp, err := handleStandardTypes(TypeParameter{Name: param.Name, Type: singleParamType, Contains: param.Contains}, item, rev) + if err != nil { + return resp, err + } + localEncode = append(localEncode, resp) + } + return rev.HashMethod(localEncode...), nil + } + + var typeDef TypeDefinition + if isPresetType(singleParamType) { + typeDef, ok = rev.Types().Preset[singleParamType] + } else { + typeDef, ok = typedData.Types[singleParamType] + } + if !ok { + return resp, fmt.Errorf("error trying to get the type definition of '%s'", singleParamType) + } + + for _, item := range dataArray { + resp, err := handleObjectTypes(&typeDef, item, isEnum) + if err != nil { + return resp, err + } + localEncode = append(localEncode, resp) + } + + if len(isMerkle) != 0 { + return handleMerkleTree(localEncode), nil + } + return rev.HashMethod(localEncode...), nil + } + + //function logic + if strings.HasSuffix(param.Type, "*") { + resp, err := handleArrays(param, data, typedData.Revision) + if err != nil { + return resp, err + } + return resp, nil + } + + if isStandardType(param.Type) { + resp, err := handleStandardTypes(param, data, typedData.Revision) + if err != nil { + return resp, err + } + return resp, nil + } + + nextTypeDef, ok := typedData.Types[param.Type] + if !ok { + return resp, fmt.Errorf("error trying to get the type definition of '%s'", param.Type) + } + resp, err = handleObjectTypes(&nextTypeDef, data, isEnum) + if err != nil { + return resp, err + } + return resp, nil + } + + getData := func(key string) (any, error) { + value, ok := data[key] + if !ok { + return value, fmt.Errorf("error trying to get the value of the '%s' param", key) + } + return value, nil + } + + // function logic + for paramIndex, param := range typeDef.Parameters { + if isEnum { + value, exists := data[param.Name] + // check if it's the selected enum option + if !exists { + if paramIndex == len(typeDef.Parameters)-1 { + return enc, fmt.Errorf("no enum option selected for '%s', the data is not valid", typeDef.Name) + } + continue + } + + dataArr, ok := value.([]any) + if !ok { + return enc, fmt.Errorf("error trying to convert the data value of '%s' to an array", param.Name) + } + + enc = append(enc, new(felt.Felt).SetUint64(uint64(paramIndex))) + + if len(dataArr) == 0 { + enc = append(enc, &felt.Zero) + break + } + + reg := regexp.MustCompile(`[^\(\),\s]+`) + typesArr := reg.FindAllString(param.Type, -1) + + for i, typeNam := range typesArr { + resp, err := verifyType(TypeParameter{Type: typeNam}, dataArr[i], false) + if err != nil { + return enc, err + } + enc = append(enc, resp) + } + + break + } + + localData, err := getData(param.Name) + if err != nil { + return enc, err + } + + resp, err := verifyType(param, localData, false) + if err != nil { + return enc, err + } + enc = append(enc, resp) + } + + return enc, nil +} + +// encodePieceOfData encodes a single piece of data based on its type. +// Parameters: +// - typeName: the type of data to encode +// - data: the actual data to encode +// - rev: revision information +// Returns: +// - resp: encoded data as a felt.Felt +// - err: any error encountered during encoding +func encodePieceOfData(typeName string, data any, rev *revision) (resp *felt.Felt, err error) { + getFeltFromData := func() (feltValue *felt.Felt, err error) { + strValue := func(data any) string { + switch v := data.(type) { + case string: + return v + case float64: + // Handle floating point numbers without trailing zeros + if float64(int64(v)) == v { + return strconv.FormatInt(int64(v), 10) + } + return strconv.FormatFloat(v, 'f', -1, 64) + case float32: + if float32(int32(v)) == v { + return strconv.FormatInt(int64(v), 10) + } + return strconv.FormatFloat(float64(v), 'f', -1, 32) + case int: + return strconv.Itoa(v) + case int64: + return strconv.FormatInt(v, 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case bool: + return strconv.FormatBool(v) + case nil: + return "" + default: + return fmt.Sprintf("%v", v) + } + }(data) + hexValue := utils.StrToHex(strValue) + feltValue, err = utils.HexToFelt(hexValue) + if err != nil { + return feltValue, err + } + + return feltValue, nil + } + + switch typeName { + case "felt", "shortstring", "u128", "ContractAddress", "ClassHash", "timestamp": + resp, err = getFeltFromData() + if err != nil { + return resp, err + } + return resp, nil + case "bool": + boolVal, ok := data.(bool) + if !ok { + return resp, fmt.Errorf("faild to convert '%v' to 'bool'", data) + } + if boolVal { + return new(felt.Felt).SetUint64(1), nil + } + return new(felt.Felt).SetUint64(0), nil + case "i128": + strValue := fmt.Sprintf("%v", data) + bigNum, ok := new(big.Int).SetString(strValue, 0) + if !ok { + return resp, fmt.Errorf("faild to convert '%s' of type 'i128' to big.Int", strValue) + } + feltValue := new(felt.Felt).SetBigInt(bigNum) + return feltValue, nil + case "string": + if rev.Version() == 0 { + resp, err := getFeltFromData() + if err != nil { + return resp, err + } + return resp, nil + } else { + value := fmt.Sprintf("%v", data) + byteArr, err := utils.StringToByteArrFelt(value) + if err != nil { + return resp, err + } + return rev.HashMethod(byteArr...), nil + } + case "selector": + value := fmt.Sprintf("%v", data) + return utils.GetSelectorFromNameFelt(value), nil + default: + return resp, fmt.Errorf("invalid type '%s'", typeName) + } +} + +// UnmarshalJSON implements the json.Unmarshaler interface for TypedData +func (td *TypedData) UnmarshalJSON(data []byte) error { + var dec map[string]json.RawMessage + if err := json.Unmarshal(data, &dec); err != nil { + return err + } + + // primaryType + primaryType, err := utils.GetAndUnmarshalJSONFromMap[string](dec, "primaryType") + if err != nil { + return err + } + + // domain + domain, err := utils.GetAndUnmarshalJSONFromMap[Domain](dec, "domain") + if err != nil { + return err + } + + // types + rawTypes, err := utils.GetAndUnmarshalJSONFromMap[map[string]json.RawMessage](dec, "types") + if err != nil { + return err + } + var types []TypeDefinition + for key, value := range rawTypes { + var params []TypeParameter + if err := json.Unmarshal(value, ¶ms); err != nil { + return err + } + + typeDef := TypeDefinition{ + Name: key, + Parameters: params, + } + + types = append(types, typeDef) + } + + // message + rawMessage, ok := dec["message"] + if !ok { + return fmt.Errorf("invalid typedData json: missing field 'message'") + } + bytesMessage, err := json.Marshal(rawMessage) + if err != nil { + return err + } + + // result + resultTypedData, err := NewTypedData(types, primaryType, domain, bytesMessage) + if err != nil { + return err + } + + *td = *resultTypedData + return nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface for Domain +func (domain *Domain) UnmarshalJSON(data []byte) error { + var dec map[string]any + if err := json.Unmarshal(data, &dec); err != nil { + return err + } + + getField := func(fieldName string) (string, error) { + value, ok := dec[fieldName] + if !ok { + return "", fmt.Errorf("error getting the value of '%s' from 'domain' struct", fieldName) + } + return fmt.Sprintf("%v", value), nil + } + + name, err := getField("name") + if err != nil { + return err + } + + version, err := getField("version") + if err != nil { + return err + } + + revision, err := getField("revision") + if err != nil { + revision = "0" + } + numRevision, err := strconv.ParseUint(revision, 10, 8) + if err != nil { + return err + } + + chainId, err := getField("chainId") + if err != nil { + if numRevision == 1 { + return err + } + var err2 error + // ref: https://community.starknet.io/t/signing-transactions-and-off-chain-messages/66 + chainId, err2 = getField("chain_id") + if err2 != nil { + return fmt.Errorf("%w: %w", err, err2) + } + } + + *domain = Domain{ + Name: name, + Version: version, + ChainId: chainId, + Revision: uint8(numRevision), + } + return nil +} diff --git a/typedData/typedData_test.go b/typedData/typedData_test.go new file mode 100644 index 00000000..318e9c62 --- /dev/null +++ b/typedData/typedData_test.go @@ -0,0 +1,346 @@ +package typedData + +import ( + "encoding/json" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +var typedDataExamples = make(map[string]TypedData) + +// TestMain initializes test data by loading TypedData examples from JSON files. +// It reads multiple test files and stores them in the typedDataExamples map +// before running the tests. +// +// Parameters: +// - m: The testing.M object that provides the test runner +// Returns: +// - None (calls os.Exit directly) +func TestMain(m *testing.M) { + fileNames := []string{ + "baseExample", + "example_array", + "example_baseTypes", + "example_enum", + "example_presetTypes", + "mail_StructArray", + "session_MerkleTree", + "v1Nested", + "allInOne", + } + + for _, fileName := range fileNames { + var ttd TypedData + content, err := os.ReadFile(fmt.Sprintf("./tests/%s.json", fileName)) + if err != nil { + panic(fmt.Errorf("fail to read file: %w", err)) + } + err = json.Unmarshal(content, &ttd) + if err != nil { + panic(fmt.Errorf("fail to unmarshal TypedData: %w", err)) + } + + typedDataExamples[fileName] = ttd + } + + os.Exit(m.Run()) +} + +// BMockTypedData is a helper function for benchmarks that loads a base example +// TypedData from a JSON file. +// +// Parameters: +// - b: The testing.B object used for benchmarking +// Returns: +// - ttd: A TypedData instance loaded from the base example file +func BMockTypedData(b *testing.B) (ttd TypedData) { + b.Helper() + content, err := os.ReadFile("./tests/baseExample.json") + require.NoError(b, err) + + err = json.Unmarshal(content, &ttd) + require.NoError(b, err) + + return +} + +// TestMessageHash tests the GetMessageHash function. +// +// It creates a mock TypedData and sets up a test case for hashing a mail message. +// The mail message contains information about the sender and recipient, as well as the contents of the message. +// The function then calls the GetMessageHash function with the necessary parameters to calculate the message hash. +// If an error occurs during the hashing process, an error is reported using the t.Errorf function. +// The expected hash value is compared with the actual hash value returned by the function. +// If the values do not match, an error is reported using the t.Errorf function. +// +// Parameters: +// - t: a testing.T object that provides methods for testing functions +// Returns: +// - None +func TestGetMessageHash(t *testing.T) { + type testSetType struct { + TypedData TypedData + Address string + ExpectedMessageHash string + } + testSet := []testSetType{ + { + TypedData: typedDataExamples["baseExample"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x6fcff244f63e38b9d88b9e3378d44757710d1b244282b435cb472053c8d78d0", + }, + { + TypedData: typedDataExamples["example_array"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x88edea26d6177a8bc545b2e73c960ab7ddd67b46237b386b514e50315ce0f4", + }, + { + TypedData: typedDataExamples["example_baseTypes"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0xdb7829db8909c0c5496f5952bcfc4fc894341ce01842537fc4f448743480b6", + }, + { + TypedData: typedDataExamples["example_presetTypes"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x185b339d5c566a883561a88fb36da301051e2c0225deb325c91bb7aa2f3473a", + }, + { + TypedData: typedDataExamples["session_MerkleTree"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x751fb7d98545f7649d0d0eadc80d770fcd88d8cfaa55590b284f4e1b701ef0a", + }, + { + TypedData: typedDataExamples["mail_StructArray"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x5914ed2764eca2e6a41eb037feefd3d2e33d9af6225a9e7fe31ac943ff712c", + }, + { + TypedData: typedDataExamples["v1Nested"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x69b57cf0cd7c151c51f9616cc58a1f0a877fec28c8c15ff7537cf777c54a30d", + }, + { + TypedData: typedDataExamples["example_enum"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x416b85b18063b1b3420ab709e9d5e35cb716691d397c5841ce7c5198ee30bf", + }, + { + TypedData: typedDataExamples["allInOne"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x300dc63cc85a15529bba5ed482009be716645fa9f2c64bd8716cf6f34767651", + }, + } + + for _, test := range testSet { + hash, err := test.TypedData.GetMessageHash(test.Address) + require.NoError(t, err) + + require.Equal(t, test.ExpectedMessageHash, hash.String()) + } +} + +// BenchmarkGetMessageHash is a benchmark function for testing the GetMessageHash function. +// +// It tests the performance of the GetMessageHash function by running it with different input sizes. +// The input size is determined by the bit length of the address parameter, which is converted from +// a hexadecimal string to a big integer using the HexToBN function from the utils package. +// +// Parameters: +// - b: a testing.B object that provides methods for benchmarking the function +// Returns: +// +// none +func BenchmarkGetMessageHash(b *testing.B) { + ttd := BMockTypedData(b) + + addr := "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826" + b.Run(fmt.Sprintf("input_size_%d", len(addr)), func(b *testing.B) { + result, err := ttd.GetMessageHash(addr) + require.NoError(b, err) + require.NotEmpty(b, result) + }) +} + +// TestGeneral_GetTypeHash tests the GetTypeHash function. +// +// It tests the GetTypeHash function by calling it with different input values +// and comparing the result with expected values. It also checks that the +// encoding of the types matches the expected values. +// +// Parameters: +// - t: The testing.T object used for reporting test failures and logging test output +// Returns: +// +// none +func TestGetTypeHash(t *testing.T) { + type testSetType struct { + TypedData TypedData + TypeName string + ExpectedHash string + } + testSet := []testSetType{ + { + TypedData: typedDataExamples["baseExample"], + TypeName: "StarkNetDomain", + ExpectedHash: "0x1bfc207425a47a5dfa1a50a4f5241203f50624ca5fdf5e18755765416b8e288", + }, + { + TypedData: typedDataExamples["baseExample"], + TypeName: "Mail", + ExpectedHash: "0x13d89452df9512bf750f539ba3001b945576243288137ddb6c788457d4b2f79", + }, + { + TypedData: typedDataExamples["example_baseTypes"], + TypeName: "Example", + ExpectedHash: "0x1f94cd0be8b4097a41486170fdf09a4cd23aefbc74bb2344718562994c2c111", + }, + { + TypedData: typedDataExamples["example_presetTypes"], + TypeName: "Example", + ExpectedHash: "0x1a25a8bb84b761090b1fadaebe762c4b679b0d8883d2bedda695ea340839a55", + }, + { + TypedData: typedDataExamples["session_MerkleTree"], + TypeName: "Session", + ExpectedHash: "0x1aa0e1c56b45cf06a54534fa1707c54e520b842feb21d03b7deddb6f1e340c", + }, + } + for _, test := range testSet { + hash, err := test.TypedData.GetTypeHash(test.TypeName) + require.NoError(t, err) + + require.Equal(t, test.ExpectedHash, hash.String()) + } +} + +// TestEncodeType tests the EncodeType function. +// +// It creates a mock typed data and calls the EncodeType method with the +// type name. It checks if the returned encoding matches the expected +// encoding. If there is an error during the encoding process, it fails the +// test. +// +// Parameters: +// - t: The testing.T object used for reporting test failures and logging test output +// Returns: +// +// none +func TestEncodeType(t *testing.T) { + type testSetType struct { + TypedData TypedData + TypeName string + ExpectedEncode string + } + testSet := []testSetType{ + { + TypedData: typedDataExamples["baseExample"], + TypeName: "StarkNetDomain", + ExpectedEncode: "StarkNetDomain(name:felt,version:felt,chainId:felt)", + }, + { + TypedData: typedDataExamples["baseExample"], + TypeName: "Mail", + ExpectedEncode: "Mail(from:Person,to:Person,contents:felt)Person(name:felt,wallet:felt)", + }, + { + TypedData: typedDataExamples["example_array"], + TypeName: "StarknetDomain", + ExpectedEncode: `"StarknetDomain"("name":"shortstring","version":"shortstring","chainId":"shortstring","revision":"shortstring")`, + }, + { + TypedData: typedDataExamples["example_baseTypes"], + TypeName: "Example", + ExpectedEncode: `"Example"("n0":"felt","n1":"bool","n2":"string","n3":"selector","n4":"u128","n5":"i128","n6":"ContractAddress","n7":"ClassHash","n8":"timestamp","n9":"shortstring")`, + }, + { + TypedData: typedDataExamples["example_presetTypes"], + TypeName: "Example", + ExpectedEncode: `"Example"("n0":"TokenAmount","n1":"NftId")"NftId"("collection_address":"ContractAddress","token_id":"u256")"TokenAmount"("token_address":"ContractAddress","amount":"u256")"u256"("low":"u128","high":"u128")`, + }, + { + TypedData: typedDataExamples["session_MerkleTree"], + TypeName: "Session", + ExpectedEncode: `Session(key:felt,expires:felt,root:merkletree)`, + }, + { + TypedData: typedDataExamples["mail_StructArray"], + TypeName: "Mail", + ExpectedEncode: `Mail(from:Person,to:Person,posts_len:felt,posts:Post*)Person(name:felt,wallet:felt)Post(title:felt,content:felt)`, + }, + { + TypedData: typedDataExamples["v1Nested"], + TypeName: "TransferERC721", + ExpectedEncode: `"TransferERC721"("MessageId":"felt","From":"Account1","To":"Account1","Nft_to_transfer":"Nft","Comment1":"string","Comment2":"string","Comment3":"string")"Account1"("Name":"string","Address":"felt")"Nft"("Collection":"string","Address":"felt","Nft_id":"felt","Negotiated_for":"Transaction")"Transaction"("Qty":"string","Unit":"string","Token_address":"felt","Amount":"felt")`, + }, + { + TypedData: typedDataExamples["example_enum"], + TypeName: "Example", + ExpectedEncode: `"Example"("someEnum1":"EnumA","someEnum2":"EnumB")"EnumA"("Variant 1":(),"Variant 2":("u128","u128*"),"Variant 3":("u128"))"EnumB"("Variant 1":(),"Variant 2":("u128"))`, + }, + } + for _, test := range testSet { + require.Equal(t, test.ExpectedEncode, test.TypedData.Types[test.TypeName].EncoddingString) + } +} + +// TestGetStructHash tests the GetStructHash function. +// +// It creates a mock typed data and calls the GetStructHash method with the +// type name. It checks if the returned encoding matches the expected +// encoding. If there is an error during the encoding process, it fails the +// test. +// +// Parameters: +// - t: The testing.T object used for reporting test failures and logging test output +// Returns: +// +// none +func TestGetStructHash(t *testing.T) { + type testSetType struct { + TypedData TypedData + TypeName string + Context []string + ExpectedHash string + } + testSet := []testSetType{ + { + TypedData: typedDataExamples["baseExample"], + TypeName: "StarkNetDomain", + ExpectedHash: "0x54833b121883a3e3aebff48ec08a962f5742e5f7b973469c1f8f4f55d470b07", + }, + { + TypedData: typedDataExamples["example_baseTypes"], + TypeName: "Example", + ExpectedHash: "0x75db031c1f5bf980cc48f46943b236cb85a95c8f3b3c8203572453075d3d39", + }, + { + TypedData: typedDataExamples["example_presetTypes"], + TypeName: "Example", + ExpectedHash: "0x74fba3f77f8a6111a9315bac313bf75ecfa46d1234e0fda60312fb6a6517667", + }, + { + TypedData: typedDataExamples["session_MerkleTree"], + TypeName: "Session", + ExpectedHash: "0x73602062421caf6ad2e942253debfad4584bff58930981364dcd378021defe8", + }, + { + TypedData: typedDataExamples["v1Nested"], + TypeName: "TransferERC721", + ExpectedHash: "0x11b5fb80dd88c3d8b6239b065def4ac9a79e6995b117ed5940a3a0734324b79", + }, + { + TypedData: typedDataExamples["example_enum"], + TypeName: "Example", + ExpectedHash: "0x1551dc992033e2256a2f7ec849495d90f9759ebb535e3006d16e2b9e3b57b4c", + }, + } + for _, test := range testSet { + hash, err := test.TypedData.GetStructHash(test.TypeName, test.Context...) + require.NoError(t, err) + + require.Equal(t, test.ExpectedHash, hash.String()) + } +} diff --git a/utils/Felt.go b/utils/Felt.go index 002b7cba..37343c10 100644 --- a/utils/Felt.go +++ b/utils/Felt.go @@ -92,6 +92,20 @@ func FeltArrToBigIntArr(f []*felt.Felt) []*big.Int { return bigArr } +// FeltArrToStringArr converts an array of Felt objects to an array of string objects. +// +// Parameters: +// - f: the array of Felt objects to convert +// Returns: +// - []string: the array of string objects +func FeltArrToStringArr(f []*felt.Felt) []string { + stringArr := make([]string, len(f)) + for i, felt := range f { + stringArr[i] = felt.String() + } + return stringArr +} + // StringToByteArrFelt converts string to array of Felt objects. // The returned array of felts will be of the format // @@ -117,7 +131,7 @@ func StringToByteArrFelt(s string) ([]*felt.Felt, error) { arr := r.FindAllString(s, -1) if len(arr) == 0 { - return []*felt.Felt{}, fmt.Errorf("invalid string no matches found, s: %s", s) + return []*felt.Felt{&felt.Zero, &felt.Zero, &felt.Zero}, nil } hexarr := []string{} diff --git a/utils/data.go b/utils/data.go index ce5d2098..f813aa84 100644 --- a/utils/data.go +++ b/utils/data.go @@ -1,6 +1,9 @@ package utils -import "encoding/json" +import ( + "encoding/json" + "fmt" +) func UnwrapJSON(data map[string]interface{}, tag string) (map[string]interface{}, error) { if data[tag] != nil { @@ -16,3 +19,17 @@ func UnwrapJSON(data map[string]interface{}, tag string) (map[string]interface{} } return data, nil } + +func GetAndUnmarshalJSONFromMap[T any](aMap map[string]json.RawMessage, key string) (result T, err error) { + value, ok := aMap[key] + if !ok { + return result, fmt.Errorf("invalid json: missing field %s", key) + } + + err = json.Unmarshal(value, &result) + if err != nil { + return result, err + } + + return result, nil +} diff --git a/utils/keccak.go b/utils/keccak.go index 1133df96..79553b62 100644 --- a/utils/keccak.go +++ b/utils/keccak.go @@ -46,6 +46,24 @@ func StrToBig(str string) *big.Int { return b } +// StrToBig generates a hexadecimal from a string/number representation. +// +// Parameters: +// - str: The string to convert to a hexadecimal +// Returns: +// - hex: a string representing the converted value +func StrToHex(str string) string { + if strings.HasPrefix(str, "0x") { + return str + } + + if bigNum, ok := new(big.Int).SetString(str, 0); ok { + return "0x" + bigNum.Text(16) + } + + return "0x" + fmt.Sprintf("%x", str) +} + // HexToShortStr converts a hexadecimal string to a short string (Starknet) representation. // // Parameters: diff --git a/utils/keccak_test.go b/utils/keccak_test.go new file mode 100644 index 00000000..8395f001 --- /dev/null +++ b/utils/keccak_test.go @@ -0,0 +1,30 @@ +package utils + +import "testing" + +// TestGetSelectorFromName tests the GetSelectorFromName function. +// +// It checks if the GetSelectorFromName function returns the expected values +// for different input names. +// The expected values are hard-coded and compared against the actual values. +// If any of the actual values do not match the expected values, an error is +// reported. +// +// Parameters: +// - t: The testing.T object used for reporting test failures and logging test output +// Returns: +// +// none +func TestGetSelectorFromName(t *testing.T) { + sel1 := BigToHex(GetSelectorFromName("initialize")) + sel2 := BigToHex(GetSelectorFromName("mint")) + sel3 := BigToHex(GetSelectorFromName("test")) + + exp1 := "0x79dc0da7c54b95f10aa182ad0a46400db63156920adb65eca2654c0945a463" + exp2 := "0x2f0b3c5710379609eb5495f1ecd348cb28167711b73609fe565a72734550354" + exp3 := "0x22ff5f21f0b81b113e63f7db6da94fedef11b2119b4088b89664fb9a3cb658" + + if sel1 != exp1 || sel2 != exp2 || sel3 != exp3 { + t.Errorf("invalid Keccak256 encoding: %v %v %v\n", sel1, sel2, sel3) + } +}