Skip to content

Commit

Permalink
Implement HMAC Reset() and make Sum() friendlier
Browse files Browse the repository at this point in the history
re #7
  • Loading branch information
Richard Kettlewell authored and Richard Kettlewell committed Aug 7, 2018
1 parent e5c8975 commit dbbc822
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 33 deletions.
91 changes: 62 additions & 29 deletions hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package crypto11

import (
"context"
"errors"
"fmt"
"github.com/miekg/pkcs11"
"github.com/youtube/vitess/go/pools"
Expand Down Expand Up @@ -56,12 +57,23 @@ type hmacImplementation struct {
// PKCS#11 session to use
session *PKCS11Session

// Signing key
key *PKCS11SecretKey

// Hash size
size int

// Block size
blockSize int

// PKCS#11 mechanism information
mechDescription []*pkcs11.Mechanism

// Cleanup function
cleanup func()

// Result, or nil if we don't have the answer yet
result []byte
}

type hmacInfo struct {
Expand Down Expand Up @@ -91,6 +103,9 @@ var hmacInfos = map[int]*hmacInfo{
pkcs11.CKM_RIPEMD160_HMAC_GENERAL: {20, 64, true},
}

// ErrHmacClosed is called if an HMAC is updated after it has finished.
var ErrHmacClosed = errors.New("already called Sum()")

// NewHMAC returns a new HMAC hash using the given PKCS#11 mechanism
// and key.
// length specifies the output size, for _GENERAL mechanisms.
Expand All @@ -99,14 +114,38 @@ var hmacInfos = map[int]*hmacInfo{
// Size() function will return whatever length was, even if it is wrong.
// BlockSize() will always return 0 in this case.
//
// The Reset() method is not implemented, and Sum() may only be called once.
// The former limitation may be lifted in future but the latter is fundamental
// and will not change.
// The Reset() method is not implemented.
// After Sum() is called no new data may be added.
func (key *PKCS11SecretKey) NewHMAC(mech int, length int) (h hash.Hash, err error) {
var hi hmacImplementation
hi = hmacImplementation{
key: key,
}
var params []byte
if info, ok := hmacInfos[mech]; ok {
hi.blockSize = info.blockSize
if info.general {
hi.size = length
params = ulongToBytes(uint(length))
} else {
hi.size = info.size
}
} else {
hi.size = length
}
hi.mechDescription = []*pkcs11.Mechanism{pkcs11.NewMechanism(uint(mech), params)}
if err = hi.initialize(); err != nil {
return
}
h = &hi
return
}

func (hi *hmacImplementation) initialize() (err error) {
// TODO refactor with newBlockModeCloser
sessionPool := pool.Get(key.Slot)
sessionPool := pool.Get(hi.key.Slot)
if sessionPool == nil {
err = fmt.Errorf("crypto11: no session for slot %d", key.Slot)
err = fmt.Errorf("crypto11: no session for slot %d", hi.key.Slot)
return
}
ctx, cancel := context.WithTimeout(context.Background(), newSessionTimeout)
Expand All @@ -115,34 +154,26 @@ func (key *PKCS11SecretKey) NewHMAC(mech int, length int) (h hash.Hash, err erro
if session, err = sessionPool.Get(ctx); err != nil {
return
}
var hi hmacImplementation
hi.session = session.(*PKCS11Session)
hi.cleanup = func() {
sessionPool.Put(session)
hi.session = nil
}
var params []byte
if info, ok := hmacInfos[mech]; ok {
hi.blockSize = info.blockSize
if info.general {
hi.size = length
params = ulongToBytes(uint(length))
} else {
hi.size = info.size
}
} else {
hi.size = length
}
mechDescription := []*pkcs11.Mechanism{pkcs11.NewMechanism(uint(mech), params)}
if err = hi.session.Ctx.SignInit(hi.session.Handle, mechDescription, key.Handle); err != nil {
if err = hi.session.Ctx.SignInit(hi.session.Handle, hi.mechDescription, hi.key.Handle); err != nil {
hi.cleanup()
return
}
h = &hi
hi.result = nil
return
}

func (hi *hmacImplementation) Write(p []byte) (n int, err error) {
if hi.result != nil {
if len(p) > 0 {
err = ErrHmacClosed
}
return
}
if err = hi.session.Ctx.#date(hi.session.Handle, p); err != nil {
return
}
Expand All @@ -151,18 +182,20 @@ func (hi *hmacImplementation) Write(p []byte) (n int, err error) {
}

func (hi *hmacImplementation) Sum(b []byte) []byte {
var result []byte
var err error
result, err = hi.session.Ctx.SignFinal(hi.session.Handle)
hi.cleanup()
if err != nil {
panic(err)
if hi.result == nil {
var err error
hi.result, err = hi.session.Ctx.SignFinal(hi.session.Handle)
hi.cleanup()
if err != nil {
panic(err)
}
}
return append(b, result...)
return append(b, hi.result...)
}

func (hi *hmacImplementation) Reset() {
panic("Reset not implemented")
hi.Sum(nil) // Clean up
hi.initialize()
}

func (hi *hmacImplementation) Size() int {
Expand Down
70 changes: 66 additions & 4 deletions hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ func TestHmac(t *testing.T) {
t.Skipf("HMAC not implemented on SoftHSM")
}
t.Run("HMACSHA1", func(t *testing.T) {
testHmac(t, pkcs11.CKK_SHA_1_HMAC, pkcs11.CKM_SHA_1_HMAC, 0, 20)
testHmac(t, pkcs11.CKK_SHA_1_HMAC, pkcs11.CKM_SHA_1_HMAC, 0, 20, false)
})
t.Run("HMACSHA1General", func(t *testing.T) {
testHmac(t, pkcs11.CKK_SHA_1_HMAC, pkcs11.CKM_SHA_1_HMAC_GENERAL, 10, 10)
testHmac(t, pkcs11.CKK_SHA_1_HMAC, pkcs11.CKM_SHA_1_HMAC_GENERAL, 10, 10, true)
})
t.Run("HMACSHA256", func(t *testing.T) {
testHmac(t, pkcs11.CKK_SHA256_HMAC, pkcs11.CKM_SHA256_HMAC, 0, 32)
testHmac(t, pkcs11.CKK_SHA256_HMAC, pkcs11.CKM_SHA256_HMAC, 0, 32, false)
})
Close()
}

func testHmac(t *testing.T, keytype int, mech int, length int, xlength int) {
func testHmac(t *testing.T, keytype int, mech int, length int, xlength int, full bool) {
var err error
var key *PKCS11SecretKey
t.Run("Generate", func(t *testing.T) {
Expand Down Expand Up @@ -97,4 +97,66 @@ func testHmac(t *testing.T, keytype int, mech int, length int, xlength int) {
return
}
})
if full { // Independent of hash, only do these once
t.Run("MultiSum", func(t *testing.T) {
input := []byte("a different short string")
var h1 hash.Hash
if h1, err = key.NewHMAC(mech, length); err != nil {
t.Errorf("key.NewHMAC: %v", err)
return
}
if n, err := h1.Write(input); err != nil || n != len(input) {
t.Errorf("h1.Write: %v/%d", err, n)
return
}
r1 := h1.Sum([]byte{})
r2 := h1.Sum([]byte{})
if bytes.Compare(r1, r2) != 0 {
t.Errorf("r1/r2 inconsistent")
return
}
// Can't add more after Sum()
if n, err := h1.Write(input); err != ErrHmacClosed {
t.Errorf("h1.Write: %v/%d", err, n)
return
}
// 0-length is special
if n, err := h1.Write([]byte{}); err != nil || n != 0 {
t.Errorf("h1.Write: %v/%d", err, n)
return
}
})
t.Run("Reset", func(t *testing.T) {
var h1 hash.Hash
if h1, err = key.NewHMAC(mech, length); err != nil {
t.Errorf("key.NewHMAC: %v", err)
return
}
if n, err := h1.Write([]byte{1}); err != nil || n != 1 {
t.Errorf("h1.Write: %v/%d", err, n)
return
}
r1 := h1.Sum([]byte{})
h1.Reset()
if n, err := h1.Write([]byte{2}); err != nil || n != 1 {
t.Errorf("h1.Write: %v/%d", err, n)
return
}
r2 := h1.Sum([]byte{})
h1.Reset()
if n, err := h1.Write([]byte{1}); err != nil || n != 1 {
t.Errorf("h1.Write: %v/%d", err, n)
return
}
r3 := h1.Sum([]byte{})
if bytes.Compare(r1, r3) != 0 {
t.Errorf("r1/r3 inconsistent")
return
}
if bytes.Compare(r1, r2) == 0 {
t.Errorf("r1/r2 unexpectedly equal")
return
}
})
}
}

0 comments on commit dbbc822

Please # to comment.