Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Improve multi-recipient/multi-sig handling #111

Merged
merged 1 commit into from
Sep 22, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 72 additions & 5 deletions crypter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package jose
import (
"crypto/ecdsa"
"crypto/rsa"
"errors"
"fmt"
"reflect"
)
Expand Down Expand Up @@ -292,10 +293,16 @@ func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JsonWe
return obj, nil
}

// Decrypt and validate the object and return the plaintext.
// Decrypt and validate the object and return the plaintext. Note that this
// function does not support multi-recipient, if you desire multi-recipient
// decryption use DecryptMulti instead.
func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
headers := obj.mergedHeaders(nil)

if len(obj.recipients) > 1 {
return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
}

if len(headers.Crit) > 0 {
return nil, fmt.Errorf("square/go-jose: unsupported crit header")
}
Expand Down Expand Up @@ -323,27 +330,87 @@ func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error)
authData := obj.computeAuthData()

var plaintext []byte
for _, recipient := range obj.recipients {
recipient := obj.recipients[0]
recipientHeaders := obj.mergedHeaders(&recipient)

cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
if err == nil {
// Found a valid CEK -- let's try to decrypt.
plaintext, err = cipher.decrypt(cek, authData, parts)
}

if plaintext == nil {
return nil, ErrCryptoFailure
}

// The "zip" header parameter may only be present in the protected header.
if obj.protected.Zip != "" {
plaintext, err = decompress(obj.protected.Zip, plaintext)
}

return plaintext, err
}

// DecryptMulti decrypts and validates the object and returns the plaintexts,
// with support for multiple recipients. It returns the index of the recipient
// for which the decryption was successful, the merged headers for that recipient,
// and the plaintext.
func (obj JsonWebEncryption) DecryptMulti(decryptionKey interface{}) (int, JoseHeader, []byte, error) {
globalHeaders := obj.mergedHeaders(nil)

if len(globalHeaders.Crit) > 0 {
return -1, JoseHeader{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
}

decrypter, err := newDecrypter(decryptionKey)
if err != nil {
return -1, JoseHeader{}, nil, err
}

cipher := getContentCipher(globalHeaders.Enc)
if cipher == nil {
return -1, JoseHeader{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(globalHeaders.Enc))
}

generator := randomKeyGenerator{
size: cipher.keySize(),
}

parts := &aeadParts{
iv: obj.iv,
ciphertext: obj.ciphertext,
tag: obj.tag,
}

authData := obj.computeAuthData()

index := -1
var plaintext []byte
var headers rawHeader

for i, recipient := range obj.recipients {
recipientHeaders := obj.mergedHeaders(&recipient)

cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
if err == nil {
// Found a valid CEK -- let's try to decrypt.
plaintext, err = cipher.decrypt(cek, authData, parts)
if err == nil {
index = i
headers = recipientHeaders
break
}
}
}

if plaintext == nil {
return nil, ErrCryptoFailure
if plaintext == nil || err != nil {
return -1, JoseHeader{}, nil, ErrCryptoFailure
}

// The "zip" header parameter may only be present in the protected header.
if obj.protected.Zip != "" {
plaintext, err = decompress(obj.protected.Zip, plaintext)
}

return plaintext, err
return index, headers.sanitized(), plaintext, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be cleaner to not return index, headers and plaintext if err is set here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code here will not execute if decryption failed (will return on the plaintext is nil check above).

}
35 changes: 18 additions & 17 deletions crypter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func TestMultiRecipientJWE(t *testing.T) {

err = enc.AddRecipient(RSA_OAEP, &rsaTestKey.PublicKey)
if err != nil {
t.Error("error when adding RSA recipient", err)
t.Fatal("error when adding RSA recipient", err)
}

sharedKey := []byte{
Expand All @@ -282,45 +282,46 @@ func TestMultiRecipientJWE(t *testing.T) {

err = enc.AddRecipient(A256GCMKW, sharedKey)
if err != nil {
t.Error("error when adding AES recipient: ", err)
return
t.Fatal("error when adding AES recipient: ", err)
}

input := []byte("Lorem ipsum dolor sit amet")
obj, err := enc.Encrypt(input)
if err != nil {
t.Error("error in encrypt: ", err)
return
t.Fatal("error in encrypt: ", err)
}

msg := obj.FullSerialize()

parsed, err := ParseEncrypted(msg)
if err != nil {
t.Error("error in parse: ", err)
return
t.Fatal("error in parse: ", err)
}

output, err := parsed.Decrypt(rsaTestKey)
i, _, output, err := parsed.DecryptMulti(rsaTestKey)
if err != nil {
t.Error("error on decrypt with RSA: ", err)
return
t.Fatal("error on decrypt with RSA: ", err)
}

if i != 0 {
t.Fatal("recipient index should be 0 for RSA key")
}

if bytes.Compare(input, output) != 0 {
t.Error("Decrypted output does not match input: ", output, input)
return
t.Fatal("Decrypted output does not match input: ", output, input)
}

output, err = parsed.Decrypt(sharedKey)
i, _, output, err = parsed.DecryptMulti(sharedKey)
if err != nil {
t.Error("error on decrypt with AES: ", err)
return
t.Fatal("error on decrypt with AES: ", err)
}

if i != 1 {
t.Fatal("recipient index should be 1 for shared key")
}

if bytes.Compare(input, output) != 0 {
t.Error("Decrypted output does not match input", output, input)
return
t.Fatal("Decrypted output does not match input", output, input)
}
}

Expand Down
5 changes: 4 additions & 1 deletion jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ type rawSignatureInfo struct {

// JsonWebSignature represents a signed JWS object after parsing.
type JsonWebSignature struct {
payload []byte
payload []byte
// Signatures attached to this object (may be more than one for multi-sig).
// Be careful about accessing these directly, prefer to use Verify() or
// VerifyMulti() to ensure that the data you're getting is verified.
Signatures []Signature
}

Expand Down
40 changes: 37 additions & 3 deletions signing.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package jose
import (
"crypto/ecdsa"
"crypto/rsa"
"errors"
"fmt"
)

Expand Down Expand Up @@ -193,13 +194,46 @@ func (ctx *genericSigner) SetEmbedJwk(embed bool) {
}

// Verify validates the signature on the object and returns the payload.
// Note that this function does not support multi-signature, if you desire
// multi-sig verification use VerifyMulti instead.
func (obj JsonWebSignature) Verify(verificationKey interface{}) ([]byte, error) {
verifier, err := newVerifier(verificationKey)
if err != nil {
return nil, err
}

for _, signature := range obj.Signatures {
if len(obj.Signatures) > 1 {
return nil, errors.New("square/go-jose: too many signatures in payload; expecting only one")
}

signature := obj.Signatures[0]
headers := signature.mergedHeaders()
if len(headers.Crit) > 0 {
// Unsupported crit header
return nil, ErrCryptoFailure
}

input := obj.computeAuthData(&signature)
alg := SignatureAlgorithm(headers.Alg)
err = verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
return obj.payload, nil
}

return nil, ErrCryptoFailure
}

// VerifyMulti validates (one of the multiple) signatures on the object and
// returns the index of the signature that was verified, along with the signature
// object and the payload. We return the signature and index to guarantee that
// callers are getting the verified value.
func (obj JsonWebSignature) VerifyMulti(verificationKey interface{}) (int, Signature, []byte, error) {
verifier, err := newVerifier(verificationKey)
if err != nil {
return -1, Signature{}, nil, err
}

for i, signature := range obj.Signatures {
headers := signature.mergedHeaders()
if len(headers.Crit) > 0 {
// Unsupported crit header
Expand All @@ -210,9 +244,9 @@ func (obj JsonWebSignature) Verify(verificationKey interface{}) ([]byte, error)
alg := SignatureAlgorithm(headers.Alg)
err := verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
return obj.payload, nil
return i, signature, obj.payload, nil
}
}

return nil, ErrCryptoFailure
return -1, Signature{}, nil, ErrCryptoFailure
}
32 changes: 17 additions & 15 deletions signing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,43 +224,45 @@ func TestMultiRecipientJWS(t *testing.T) {
input := []byte("Lorem ipsum dolor sit amet")
obj, err := signer.Sign(input)
if err != nil {
t.Error("error on sign: ", err)
return
t.Fatal("error on sign: ", err)
}

_, err = obj.CompactSerialize()
if err == nil {
t.Error("message with multiple recipient was compact serialized")
t.Fatal("message with multiple recipient was compact serialized")
}

msg := obj.FullSerialize()

obj, err = ParseSigned(msg)
if err != nil {
t.Error("error on parse: ", err)
return
t.Fatal("error on parse: ", err)
}

output, err := obj.Verify(&rsaTestKey.PublicKey)
i, _, output, err := obj.VerifyMulti(&rsaTestKey.PublicKey)
if err != nil {
t.Error("error on verify: ", err)
return
t.Fatal("error on verify: ", err)
}

if i != 0 {
t.Fatal("signature index should be 0 for RSA key")
}

if bytes.Compare(output, input) != 0 {
t.Error("input/output do not match", output, input)
return
t.Fatal("input/output do not match", output, input)
}

output, err = obj.Verify(sharedKey)
i, _, output, err = obj.VerifyMulti(sharedKey)
if err != nil {
t.Error("error on verify: ", err)
return
t.Fatal("error on verify: ", err)
}

if i != 1 {
t.Fatal("signature index should be 1 for EC key")
}

if bytes.Compare(output, input) != 0 {
t.Error("input/output do not match", output, input)
return
t.Fatal("input/output do not match", output, input)
}
}

Expand Down