Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Commit

Permalink
Improve multi-recipient/multi-sig handling
Browse files Browse the repository at this point in the history
  • Loading branch information
csstaub committed Sep 22, 2016
1 parent d00415a commit 4ab2177
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 40 deletions.
75 changes: 71 additions & 4 deletions crypter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package jose
import (
"crypto/ecdsa"
"crypto/rsa"
"errors"
"fmt"
"reflect"
)
Expand Down Expand Up @@ -292,10 +293,16 @@ func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JsonWe
return obj, nil
}

// Decrypt and validate the object and return the plaintext.
// Decrypt and validate the object and return the plaintext. Note that this
// function does not support multi-recipient, if you desire multi-recipient
// decryption use DecryptMulti instead.
func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
headers := obj.mergedHeaders(nil)

if len(obj.recipients) > 1 {
return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
}

if len(headers.Crit) > 0 {
return nil, fmt.Errorf("square/go-jose: unsupported crit header")
}
Expand Down Expand Up @@ -323,27 +330,87 @@ func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error)
authData := obj.computeAuthData()

var plaintext []byte
for _, recipient := range obj.recipients {
recipient := obj.recipients[0]
recipientHeaders := obj.mergedHeaders(&recipient)

cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
if err == nil {
// Found a valid CEK -- let's try to decrypt.
plaintext, err = cipher.decrypt(cek, authData, parts)
}

if plaintext == nil {
return nil, ErrCryptoFailure
}

// The "zip" header parameter may only be present in the protected header.
if obj.protected.Zip != "" {
plaintext, err = decompress(obj.protected.Zip, plaintext)
}

return plaintext, err
}

// DecryptMulti decrypts and validates the object and returns the plaintexts,
// with support for multiple recipients. It returns the index of the recipient
// for which the decryption was successful, the merged headers for that recipient,
// and the plaintext.
func (obj JsonWebEncryption) DecryptMulti(decryptionKey interface{}) (int, JoseHeader, []byte, error) {
globalHeaders := obj.mergedHeaders(nil)

if len(globalHeaders.Crit) > 0 {
return -1, JoseHeader{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
}

decrypter, err := newDecrypter(decryptionKey)
if err != nil {
return -1, JoseHeader{}, nil, err
}

cipher := getContentCipher(globalHeaders.Enc)
if cipher == nil {
return -1, JoseHeader{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(globalHeaders.Enc))
}

generator := randomKeyGenerator{
size: cipher.keySize(),
}

parts := &aeadParts{
iv: obj.iv,
ciphertext: obj.ciphertext,
tag: obj.tag,
}

authData := obj.computeAuthData()

index := -1
var plaintext []byte
var headers rawHeader

for i, recipient := range obj.recipients {
recipientHeaders := obj.mergedHeaders(&recipient)

cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
if err == nil {
// Found a valid CEK -- let's try to decrypt.
plaintext, err = cipher.decrypt(cek, authData, parts)
if err == nil {
index = i
headers = recipientHeaders
break
}
}
}

if plaintext == nil {
return nil, ErrCryptoFailure
return -1, JoseHeader{}, nil, ErrCryptoFailure
}

// The "zip" header parameter may only be present in the protected header.
if obj.protected.Zip != "" {
plaintext, err = decompress(obj.protected.Zip, plaintext)
}

return plaintext, err
return index, headers.sanitized(), plaintext, err
}
35 changes: 18 additions & 17 deletions crypter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func TestMultiRecipientJWE(t *testing.T) {

err = enc.AddRecipient(RSA_OAEP, &rsaTestKey.PublicKey)
if err != nil {
t.Error("error when adding RSA recipient", err)
t.Fatal("error when adding RSA recipient", err)
}

sharedKey := []byte{
Expand All @@ -282,45 +282,46 @@ func TestMultiRecipientJWE(t *testing.T) {

err = enc.AddRecipient(A256GCMKW, sharedKey)
if err != nil {
t.Error("error when adding AES recipient: ", err)
return
t.Fatal("error when adding AES recipient: ", err)
}

input := []byte("Lorem ipsum dolor sit amet")
obj, err := enc.Encrypt(input)
if err != nil {
t.Error("error in encrypt: ", err)
return
t.Fatal("error in encrypt: ", err)
}

msg := obj.FullSerialize()

parsed, err := ParseEncrypted(msg)
if err != nil {
t.Error("error in parse: ", err)
return
t.Fatal("error in parse: ", err)
}

output, err := parsed.Decrypt(rsaTestKey)
i, _, output, err := parsed.DecryptMulti(rsaTestKey)
if err != nil {
t.Error("error on decrypt with RSA: ", err)
return
t.Fatal("error on decrypt with RSA: ", err)
}

if i != 0 {
t.Fatal("recipient index should be 0 for RSA key")
}

if bytes.Compare(input, output) != 0 {
t.Error("Decrypted output does not match input: ", output, input)
return
t.Fatal("Decrypted output does not match input: ", output, input)
}

output, err = parsed.Decrypt(sharedKey)
i, _, output, err = parsed.DecryptMulti(sharedKey)
if err != nil {
t.Error("error on decrypt with AES: ", err)
return
t.Fatal("error on decrypt with AES: ", err)
}

if i != 1 {
t.Fatal("recipient index should be 0 for RSA key")
}

if bytes.Compare(input, output) != 0 {
t.Error("Decrypted output does not match input", output, input)
return
t.Fatal("Decrypted output does not match input", output, input)
}
}

Expand Down
5 changes: 4 additions & 1 deletion jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ type rawSignatureInfo struct {

// JsonWebSignature represents a signed JWS object after parsing.
type JsonWebSignature struct {
payload []byte
payload []byte
// Signatures attached to this object (may be more than one for multi-sig).
// Be careful about accessing these directly, prefer to use Verify() or
// VerifyMulti() to ensure that the data you're getting is verified.
Signatures []Signature
}

Expand Down
40 changes: 37 additions & 3 deletions signing.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package jose
import (
"crypto/ecdsa"
"crypto/rsa"
"errors"
"fmt"
)

Expand Down Expand Up @@ -193,13 +194,46 @@ func (ctx *genericSigner) SetEmbedJwk(embed bool) {
}

// Verify validates the signature on the object and returns the payload.
// Note that this function does not support multi-signature, if you desire
// multi-sig verification use VerifyMulti instead.
func (obj JsonWebSignature) Verify(verificationKey interface{}) ([]byte, error) {
verifier, err := newVerifier(verificationKey)
if err != nil {
return nil, err
}

for _, signature := range obj.Signatures {
if len(obj.Signatures) > 1 {
return nil, errors.New("square/go-jose: too many signatures in payload; expecting only one")
}

signature := obj.Signatures[0]
headers := signature.mergedHeaders()
if len(headers.Crit) > 0 {
// Unsupported crit header
return nil, ErrCryptoFailure
}

input := obj.computeAuthData(&signature)
alg := SignatureAlgorithm(headers.Alg)
err = verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
return obj.payload, nil
}

return nil, ErrCryptoFailure
}

// VerifyMulti validates (one of the multiple) signatures on the object and
// returns the index of the signature that was verified, along with the signature
// object and the payload. We return the signature and index to guarantee that
// callers are getting the verified value.
func (obj JsonWebSignature) VerifyMulti(verificationKey interface{}) (int, Signature, []byte, error) {
verifier, err := newVerifier(verificationKey)
if err != nil {
return -1, Signature{}, nil, err
}

for i, signature := range obj.Signatures {
headers := signature.mergedHeaders()
if len(headers.Crit) > 0 {
// Unsupported crit header
Expand All @@ -210,9 +244,9 @@ func (obj JsonWebSignature) Verify(verificationKey interface{}) ([]byte, error)
alg := SignatureAlgorithm(headers.Alg)
err := verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
return obj.payload, nil
return i, signature, obj.payload, nil
}
}

return nil, ErrCryptoFailure
return -1, Signature{}, nil, ErrCryptoFailure
}
32 changes: 17 additions & 15 deletions signing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,43 +224,45 @@ func TestMultiRecipientJWS(t *testing.T) {
input := []byte("Lorem ipsum dolor sit amet")
obj, err := signer.Sign(input)
if err != nil {
t.Error("error on sign: ", err)
return
t.Fatal("error on sign: ", err)
}

_, err = obj.CompactSerialize()
if err == nil {
t.Error("message with multiple recipient was compact serialized")
t.Fatal("message with multiple recipient was compact serialized")
}

msg := obj.FullSerialize()

obj, err = ParseSigned(msg)
if err != nil {
t.Error("error on parse: ", err)
return
t.Fatal("error on parse: ", err)
}

output, err := obj.Verify(&rsaTestKey.PublicKey)
i, _, output, err := obj.VerifyMulti(&rsaTestKey.PublicKey)
if err != nil {
t.Error("error on verify: ", err)
return
t.Fatal("error on verify: ", err)
}

if i != 0 {
t.Fatal("signature index should be 0 for RSA key")
}

if bytes.Compare(output, input) != 0 {
t.Error("input/output do not match", output, input)
return
t.Fatal("input/output do not match", output, input)
}

output, err = obj.Verify(sharedKey)
i, _, output, err = obj.VerifyMulti(sharedKey)
if err != nil {
t.Error("error on verify: ", err)
return
t.Fatal("error on verify: ", err)
}

if i != 1 {
t.Fatal("signature index should be 1 for EC key")
}

if bytes.Compare(output, input) != 0 {
t.Error("input/output do not match", output, input)
return
t.Fatal("input/output do not match", output, input)
}
}

Expand Down

0 comments on commit 4ab2177

Please # to comment.