diff --git a/auth/oauth2adapt/oauth2adapt.go b/auth/oauth2adapt/oauth2adapt.go index 9835ac571cf3..e266b0c9b158 100644 --- a/auth/oauth2adapt/oauth2adapt.go +++ b/auth/oauth2adapt/oauth2adapt.go @@ -26,6 +26,13 @@ import ( "golang.org/x/oauth2/google" ) +const ( + oauth2TokenSourceKey = "oauth2.google.tokenSource" + oauth2ServiceAccountKey = "oauth2.google.serviceAccount" + authTokenSourceKey = "auth.google.tokenSource" + authServiceAccountKey = "auth.google.serviceAccount" +) + // TokenProviderFromTokenSource converts any [golang.org/x/oauth2.TokenSource] // into a [cloud.google.com/go/auth.TokenProvider]. func TokenProviderFromTokenSource(ts oauth2.TokenSource) auth.TokenProvider { @@ -47,10 +54,21 @@ func (tp *tokenProviderAdapter) Token(context.Context) (*auth.Token, error) { } return nil, err } + // Preserve compute token metadata, for both types of tokens. + metadata := map[string]interface{}{} + if val, ok := tok.Extra(oauth2TokenSourceKey).(string); ok { + metadata[authTokenSourceKey] = val + metadata[oauth2TokenSourceKey] = val + } + if val, ok := tok.Extra(oauth2ServiceAccountKey).(string); ok { + metadata[authServiceAccountKey] = val + metadata[oauth2ServiceAccountKey] = val + } return &auth.Token{ - Value: tok.AccessToken, - Type: tok.Type(), - Expiry: tok.Expiry, + Value: tok.AccessToken, + Type: tok.Type(), + Expiry: tok.Expiry, + Metadata: metadata, }, nil } @@ -76,11 +94,24 @@ func (ts *tokenSourceAdapter) Token() (*oauth2.Token, error) { } return nil, err } - return &oauth2.Token{ + tok2 := &oauth2.Token{ AccessToken: tok.Value, TokenType: tok.Type, Expiry: tok.Expiry, - }, nil + } + // Preserve token metadata. + metadata := tok.Metadata + if metadata != nil { + // Append compute token metadata in converted form. + if val, ok := metadata[authTokenSourceKey].(string); ok && val != "" { + metadata[oauth2TokenSourceKey] = val + } + if val, ok := metadata[authServiceAccountKey].(string); ok && val != "" { + metadata[oauth2ServiceAccountKey] = val + } + tok2 = tok2.WithExtra(metadata) + } + return tok2, nil } // AuthCredentialsFromOauth2Credentials converts a [golang.org/x/oauth2/google.Credentials] diff --git a/auth/oauth2adapt/oauth2adapt_test.go b/auth/oauth2adapt/oauth2adapt_test.go index d408c90c568c..808861f23b2b 100644 --- a/auth/oauth2adapt/oauth2adapt_test.go +++ b/auth/oauth2adapt/oauth2adapt_test.go @@ -252,6 +252,67 @@ func TestOauth2CredentialsFromAuthCredentials(t *testing.T) { } } +func TestMetadataConversions_ToTokenProvider(t *testing.T) { + ctx := context.Background() + tok := &oauth2.Token{AccessToken: "token"} + tok = tok.WithExtra(map[string]interface{}{ + oauth2TokenSourceKey: "compute-metadata", + oauth2ServiceAccountKey: "default", + }) + inputCreds := &google.Credentials{ + ProjectID: "test_project", + TokenSource: tokenSource{token: tok}, + JSON: []byte("json"), + UniverseDomainProvider: func() (string, error) { + return "domain", nil + }, + } + outCreds := AuthCredentialsFromOauth2Credentials(inputCreds) + tok2, err := outCreds.Token(ctx) + if err != nil { + t.Fatal(err) + } + if tok2.MetadataString(authTokenSourceKey) != "compute-metadata" { + t.Errorf("got %q, want %q", tok2.MetadataString(authTokenSourceKey), "compute-metadata") + } + if tok2.MetadataString(oauth2TokenSourceKey) != "compute-metadata" { + t.Errorf("got %q, want %q", tok2.MetadataString(oauth2TokenSourceKey), "compute-metadata") + } + if tok2.MetadataString(authServiceAccountKey) != "default" { + t.Errorf("got %q, want %q", tok2.MetadataString(authTokenSourceKey), "default") + } + if tok2.MetadataString(oauth2ServiceAccountKey) != "default" { + t.Errorf("got %q, want %q", tok2.MetadataString(oauth2ServiceAccountKey), "default") + } +} + +func TestMetadataConversions_ToTokenSource(t *testing.T) { + tok := &auth.Token{Value: "token", Metadata: map[string]interface{}{ + authTokenSourceKey: "compute-metadata", + authServiceAccountKey: "default", + }} + inputCreds := auth.NewCredentials(&auth.CredentialsOptions{ + TokenProvider: tokenProvider{token: tok}, + }) + outCreds := Oauth2CredentialsFromAuthCredentials(inputCreds) + tok2, err := outCreds.TokenSource.Token() + if err != nil { + t.Fatal(err) + } + if val, ok := tok2.Extra(oauth2TokenSourceKey).(string); !ok && val != "compute-metadata" { + t.Errorf("got %q, want %q", tok2.Extra(oauth2TokenSourceKey), "compute-metadata") + } + if val, ok := tok2.Extra(authTokenSourceKey).(string); !ok && val != "compute-metadata" { + t.Errorf("got %q, want %q", tok2.Extra(authTokenSourceKey), "compute-metadata") + } + if val, ok := tok2.Extra(oauth2ServiceAccountKey).(string); !ok && val != "default" { + t.Errorf("got %q, want %q", tok2.Extra(oauth2ServiceAccountKey), "default") + } + if val, ok := tok2.Extra(authServiceAccountKey).(string); !ok && val != "default" { + t.Errorf("got %q, want %q", tok2.Extra(authServiceAccountKey), "default") + } +} + type tokenSource struct { token *oauth2.Token err error @@ -261,10 +322,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) { if ts.err != nil { return nil, ts.err } - return &oauth2.Token{ - AccessToken: ts.token.AccessToken, - TokenType: ts.token.TokenType, - }, nil + return ts.token, nil } type tokenProvider struct { @@ -276,8 +334,5 @@ func (tp tokenProvider) Token(context.Context) (*auth.Token, error) { if tp.err != nil { return nil, tp.err } - return &auth.Token{ - Value: tp.token.Value, - Type: tp.token.Type, - }, nil + return tp.token, nil }