Skip to content

Commit

Permalink
Review: improve Azure handler.
Browse files Browse the repository at this point in the history
- replace `context.WithTimeout` with `s.Clock.After` to avoid `time.Sleep` in tests.
  • Loading branch information
Tener committed Jan 5, 2023
1 parent d89654c commit 54ccb28
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 15 deletions.
28 changes: 16 additions & 12 deletions lib/srv/app/azure/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,19 +287,23 @@ const getTokenTimeout = time.Second * 5
func (s *handler) getToken(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) {
key := cacheKey{managedIdentity, scope}

timeoutCtx, cancel := context.WithTimeout(ctx, getTokenTimeout)
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()

token, err := utils.FnCacheGet(timeoutCtx, s.tokenCache, key, func(ctx context.Context) (*azcore.AccessToken, error) {
return s.getAccessToken(ctx, managedIdentity, scope)
})

if err != nil {
if timeoutCtx.Err() == err {
return nil, trace.Wrap(err, "timeout waiting for access token for %v", getTokenTimeout)
}
return nil, trace.Wrap(err)
var tokenResult *azcore.AccessToken
var errorResult error

go func() {
tokenResult, errorResult = utils.FnCacheGet(cancelCtx, s.tokenCache, key, func(ctx context.Context) (*azcore.AccessToken, error) {
return s.getAccessToken(ctx, managedIdentity, scope)
})
cancel()
}()

select {
case <-s.Clock.After(getTokenTimeout):
return nil, trace.Wrap(context.DeadlineExceeded, "timeout waiting for access token for %v", getTokenTimeout)
case <-cancelCtx.Done():
return tokenResult, errorResult
}

return token, nil
}
20 changes: 17 additions & 3 deletions lib/srv/app/azure/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
func TestForwarder_getToken(t *testing.T) {
t.Parallel()

tests := []struct {
type testCase struct {
name string

config HandlerConfig
Expand All @@ -38,7 +38,11 @@ func TestForwarder_getToken(t *testing.T) {

wantToken *azcore.AccessToken
checkErr require.ErrorAssertionFunc
}{
}

var tests []testCase

tests = []testCase{
{
name: "base case",
config: HandlerConfig{
Expand All @@ -60,8 +64,18 @@ func TestForwarder_getToken(t *testing.T) {
{
name: "timeout",
config: HandlerConfig{
Clock: clockwork.NewFakeClock(),
getAccessToken: func(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) {
time.Sleep(getTokenTimeout * 2)
// find the fake clock from above
var clock clockwork.FakeClock
for _, test := range tests {
if test.name == "timeout" {
clock = test.config.Clock.(clockwork.FakeClock)
}
}

clock.Advance(getTokenTimeout)
clock.Sleep(getTokenTimeout * 2)
return &azcore.AccessToken{Token: "foobar"}, nil
},
},
Expand Down

0 comments on commit 54ccb28

Please # to comment.