diff --git a/crypter.go b/crypter.go index a38632dd..b3bdaec8 100644 --- a/crypter.go +++ b/crypter.go @@ -19,6 +19,7 @@ package jose import ( "crypto/ecdsa" "crypto/rsa" + "errors" "fmt" "reflect" ) @@ -292,10 +293,16 @@ func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JsonWe return obj, nil } -// Decrypt and validate the object and return the plaintext. +// Decrypt and validate the object and return the plaintext. Note that this +// function does not support multi-recipient, if you desire multi-recipient +// decryption use DecryptMulti instead. func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) { headers := obj.mergedHeaders(nil) + if len(obj.recipients) > 1 { + return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one") + } + if len(headers.Crit) > 0 { return nil, fmt.Errorf("square/go-jose: unsupported crit header") } @@ -323,7 +330,65 @@ func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) authData := obj.computeAuthData() var plaintext []byte - for _, recipient := range obj.recipients { + recipient := obj.recipients[0] + recipientHeaders := obj.mergedHeaders(&recipient) + + cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator) + if err == nil { + // Found a valid CEK -- let's try to decrypt. + plaintext, err = cipher.decrypt(cek, authData, parts) + } + + if plaintext == nil { + return nil, ErrCryptoFailure + } + + // The "zip" header parameter may only be present in the protected header. + if obj.protected.Zip != "" { + plaintext, err = decompress(obj.protected.Zip, plaintext) + } + + return plaintext, err +} + +// DecryptMulti decrypts and validates the object and returns the plaintexts, +// with support for multiple recipients. It returns the index of the recipient +// for which the decryption was successful, the merged headers for that recipient, +// and the plaintext. +func (obj JsonWebEncryption) DecryptMulti(decryptionKey interface{}) (int, JoseHeader, []byte, error) { + globalHeaders := obj.mergedHeaders(nil) + + if len(globalHeaders.Crit) > 0 { + return -1, JoseHeader{}, nil, fmt.Errorf("square/go-jose: unsupported crit header") + } + + decrypter, err := newDecrypter(decryptionKey) + if err != nil { + return -1, JoseHeader{}, nil, err + } + + cipher := getContentCipher(globalHeaders.Enc) + if cipher == nil { + return -1, JoseHeader{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(globalHeaders.Enc)) + } + + generator := randomKeyGenerator{ + size: cipher.keySize(), + } + + parts := &aeadParts{ + iv: obj.iv, + ciphertext: obj.ciphertext, + tag: obj.tag, + } + + authData := obj.computeAuthData() + + index := -1 + var plaintext []byte + var headers rawHeader + + for i, recipient := range obj.recipients { recipientHeaders := obj.mergedHeaders(&recipient) cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator) @@ -331,13 +396,15 @@ func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) // Found a valid CEK -- let's try to decrypt. plaintext, err = cipher.decrypt(cek, authData, parts) if err == nil { + index = i + headers = recipientHeaders break } } } - if plaintext == nil { - return nil, ErrCryptoFailure + if plaintext == nil || err != nil { + return -1, JoseHeader{}, nil, ErrCryptoFailure } // The "zip" header parameter may only be present in the protected header. @@ -345,5 +412,5 @@ func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) plaintext, err = decompress(obj.protected.Zip, plaintext) } - return plaintext, err + return index, headers.sanitized(), plaintext, err } diff --git a/crypter_test.go b/crypter_test.go index 86b8fc0a..431f6537 100644 --- a/crypter_test.go +++ b/crypter_test.go @@ -272,7 +272,7 @@ func TestMultiRecipientJWE(t *testing.T) { err = enc.AddRecipient(RSA_OAEP, &rsaTestKey.PublicKey) if err != nil { - t.Error("error when adding RSA recipient", err) + t.Fatal("error when adding RSA recipient", err) } sharedKey := []byte{ @@ -282,45 +282,46 @@ func TestMultiRecipientJWE(t *testing.T) { err = enc.AddRecipient(A256GCMKW, sharedKey) if err != nil { - t.Error("error when adding AES recipient: ", err) - return + t.Fatal("error when adding AES recipient: ", err) } input := []byte("Lorem ipsum dolor sit amet") obj, err := enc.Encrypt(input) if err != nil { - t.Error("error in encrypt: ", err) - return + t.Fatal("error in encrypt: ", err) } msg := obj.FullSerialize() parsed, err := ParseEncrypted(msg) if err != nil { - t.Error("error in parse: ", err) - return + t.Fatal("error in parse: ", err) } - output, err := parsed.Decrypt(rsaTestKey) + i, _, output, err := parsed.DecryptMulti(rsaTestKey) if err != nil { - t.Error("error on decrypt with RSA: ", err) - return + t.Fatal("error on decrypt with RSA: ", err) + } + + if i != 0 { + t.Fatal("recipient index should be 0 for RSA key") } if bytes.Compare(input, output) != 0 { - t.Error("Decrypted output does not match input: ", output, input) - return + t.Fatal("Decrypted output does not match input: ", output, input) } - output, err = parsed.Decrypt(sharedKey) + i, _, output, err = parsed.DecryptMulti(sharedKey) if err != nil { - t.Error("error on decrypt with AES: ", err) - return + t.Fatal("error on decrypt with AES: ", err) + } + + if i != 1 { + t.Fatal("recipient index should be 1 for shared key") } if bytes.Compare(input, output) != 0 { - t.Error("Decrypted output does not match input", output, input) - return + t.Fatal("Decrypted output does not match input", output, input) } } diff --git a/jws.go b/jws.go index 4b60bd29..55014d17 100644 --- a/jws.go +++ b/jws.go @@ -41,7 +41,10 @@ type rawSignatureInfo struct { // JsonWebSignature represents a signed JWS object after parsing. type JsonWebSignature struct { - payload []byte + payload []byte + // Signatures attached to this object (may be more than one for multi-sig). + // Be careful about accessing these directly, prefer to use Verify() or + // VerifyMulti() to ensure that the data you're getting is verified. Signatures []Signature } diff --git a/signing.go b/signing.go index c6ed2c92..2b338e65 100644 --- a/signing.go +++ b/signing.go @@ -19,6 +19,7 @@ package jose import ( "crypto/ecdsa" "crypto/rsa" + "errors" "fmt" ) @@ -193,13 +194,46 @@ func (ctx *genericSigner) SetEmbedJwk(embed bool) { } // Verify validates the signature on the object and returns the payload. +// Note that this function does not support multi-signature, if you desire +// multi-sig verification use VerifyMulti instead. func (obj JsonWebSignature) Verify(verificationKey interface{}) ([]byte, error) { verifier, err := newVerifier(verificationKey) if err != nil { return nil, err } - for _, signature := range obj.Signatures { + if len(obj.Signatures) > 1 { + return nil, errors.New("square/go-jose: too many signatures in payload; expecting only one") + } + + signature := obj.Signatures[0] + headers := signature.mergedHeaders() + if len(headers.Crit) > 0 { + // Unsupported crit header + return nil, ErrCryptoFailure + } + + input := obj.computeAuthData(&signature) + alg := SignatureAlgorithm(headers.Alg) + err = verifier.verifyPayload(input, signature.Signature, alg) + if err == nil { + return obj.payload, nil + } + + return nil, ErrCryptoFailure +} + +// VerifyMulti validates (one of the multiple) signatures on the object and +// returns the index of the signature that was verified, along with the signature +// object and the payload. We return the signature and index to guarantee that +// callers are getting the verified value. +func (obj JsonWebSignature) VerifyMulti(verificationKey interface{}) (int, Signature, []byte, error) { + verifier, err := newVerifier(verificationKey) + if err != nil { + return -1, Signature{}, nil, err + } + + for i, signature := range obj.Signatures { headers := signature.mergedHeaders() if len(headers.Crit) > 0 { // Unsupported crit header @@ -210,9 +244,9 @@ func (obj JsonWebSignature) Verify(verificationKey interface{}) ([]byte, error) alg := SignatureAlgorithm(headers.Alg) err := verifier.verifyPayload(input, signature.Signature, alg) if err == nil { - return obj.payload, nil + return i, signature, obj.payload, nil } } - return nil, ErrCryptoFailure + return -1, Signature{}, nil, ErrCryptoFailure } diff --git a/signing_test.go b/signing_test.go index 2cbdf54f..a4d637cb 100644 --- a/signing_test.go +++ b/signing_test.go @@ -224,43 +224,45 @@ func TestMultiRecipientJWS(t *testing.T) { input := []byte("Lorem ipsum dolor sit amet") obj, err := signer.Sign(input) if err != nil { - t.Error("error on sign: ", err) - return + t.Fatal("error on sign: ", err) } _, err = obj.CompactSerialize() if err == nil { - t.Error("message with multiple recipient was compact serialized") + t.Fatal("message with multiple recipient was compact serialized") } msg := obj.FullSerialize() obj, err = ParseSigned(msg) if err != nil { - t.Error("error on parse: ", err) - return + t.Fatal("error on parse: ", err) } - output, err := obj.Verify(&rsaTestKey.PublicKey) + i, _, output, err := obj.VerifyMulti(&rsaTestKey.PublicKey) if err != nil { - t.Error("error on verify: ", err) - return + t.Fatal("error on verify: ", err) + } + + if i != 0 { + t.Fatal("signature index should be 0 for RSA key") } if bytes.Compare(output, input) != 0 { - t.Error("input/output do not match", output, input) - return + t.Fatal("input/output do not match", output, input) } - output, err = obj.Verify(sharedKey) + i, _, output, err = obj.VerifyMulti(sharedKey) if err != nil { - t.Error("error on verify: ", err) - return + t.Fatal("error on verify: ", err) + } + + if i != 1 { + t.Fatal("signature index should be 1 for EC key") } if bytes.Compare(output, input) != 0 { - t.Error("input/output do not match", output, input) - return + t.Fatal("input/output do not match", output, input) } }