diff --git a/identity/credentials_oidc.go b/identity/credentials_oidc.go index 1ad0ada209c1..3709d6878384 100644 --- a/identity/credentials_oidc.go +++ b/identity/credentials_oidc.go @@ -29,6 +29,9 @@ type CredentialsOIDCProvider struct { InitialIDToken string `json:"initial_id_token"` InitialAccessToken string `json:"initial_access_token"` InitialRefreshToken string `json:"initial_refresh_token"` + CurrentIDToken string `json:"current_id_token"` + CurrentAccessToken string `json:"current_access_token"` + CurrentRefreshToken string `json:"current_refresh_token"` } // NewCredentialsOIDC creates a new OIDC credential. @@ -50,6 +53,9 @@ func NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject st InitialIDToken: idToken, InitialAccessToken: accessToken, InitialRefreshToken: refreshToken, + CurrentIDToken: idToken, + CurrentAccessToken: accessToken, + CurrentRefreshToken: refreshToken, }}, }); err != nil { return nil, errors.WithStack(x.PseudoPanic. @@ -66,3 +72,12 @@ func NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject st func OIDCUniqueID(provider, subject string) string { return fmt.Sprintf("%s:%s", provider, subject) } + +func (c *CredentialsOIDC) GetProvider(provider, subject string) (k int, found bool) { + for k, p := range c.Providers { + if p.Subject == subject && p.Provider == provider { + return k, true + } + } + return -1, false +} diff --git a/identity/identity.go b/identity/identity.go index 59ce0daedab5..a9967d90b9cb 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -406,7 +406,7 @@ func (i *Identity) WithDeclassifiedCredentials(ctx context.Context, c cipher.Pro toPublish := original toPublish.Config = []byte{} - for _, token := range []string{"initial_id_token", "initial_access_token", "initial_refresh_token"} { + for _, token := range []string{"initial_id_token", "initial_access_token", "initial_refresh_token", "current_id_token", "current_access_token", "current_refresh_token"} { var i int var err error gjson.GetBytes(original.Config, "providers").ForEach(func(_, v gjson.Result) bool { diff --git a/identity/manager.go b/identity/manager.go index f5b7a7959a8a..77f40e5d0ae0 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -167,6 +167,33 @@ func (m *Manager) Update(ctx context.Context, updated *Identity, opts ...Manager return m.r.PrivilegedIdentityPool().UpdateIdentity(ctx, updated) } +func (m *Manager) UpdateCredentials(ctx context.Context, id uuid.UUID, ct CredentialsType, cb func(*Credentials) error, opts ...ManagerOption) (err error) { + ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.Update") + defer otelx.End(span, &err) + + updated, err := m.r.PrivilegedIdentityPool().GetIdentityConfidential(ctx, id) + if err != nil { + return err + } + + c, ok := updated.GetCredentials(ct) + if !ok { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected credentials of type %s to exist but they did not.", ct)) + } + + if err := cb(c); err != nil { + return err + } + + updated.SetCredentials(ct, *c) + o := newManagerOptions(opts) + if err := m.ValidateIdentity(ctx, updated, o); err != nil { + return err + } + + return m.r.PrivilegedIdentityPool().UpdateIdentity(ctx, updated) +} + func (m *Manager) UpdateSchemaID(ctx context.Context, id uuid.UUID, schemaID string, opts ...ManagerOption) (err error) { ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.UpdateSchemaID") defer otelx.End(span, &err) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index edf3faeb80ae..f78e4d9586d3 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -132,21 +132,61 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login var o identity.CredentialsOIDC if err := json.NewDecoder(bytes.NewBuffer(c.Config)).Decode(&o); err != nil { - return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The password credentials could not be decoded properly").WithDebug(err.Error()))) + return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The oidc credentials could not be decoded properly").WithDebug(err.Error()))) } sess := session.NewInactiveSession() sess.CompletedLoginForWithProvider(s.ID(), identity.AuthenticatorAssuranceLevel1, provider.Config().ID) + found := false for _, c := range o.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { - if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, a, i, sess, provider.Config().ID); err != nil { - return nil, s.handleError(w, r, a, provider.Config().ID, nil, err) + found = true + break + } + } + + if !found { + return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject))) + } + + if err := s.d.IdentityManager().UpdateCredentials(r.Context(), i.ID, identity.CredentialsTypeOIDC, func(toUpdate *identity.Credentials) error { + var toUpdateConfig identity.CredentialsOIDC + if err := json.Unmarshal(toUpdate.Config, &toUpdateConfig); err != nil { + return err + } + + k, found := toUpdateConfig.GetProvider(provider.Config().ID, claims.Subject) + if !found { + return nil + } + + if idToken, ok := token.Extra("id_token").(string); ok { + if toUpdateConfig.Providers[k].CurrentIDToken, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(idToken)); err != nil { + return err } - return nil, nil } + + if toUpdateConfig.Providers[k].CurrentAccessToken, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(token.AccessToken)); err != nil { + return err + } + + if toUpdateConfig.Providers[k].CurrentRefreshToken, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(token.RefreshToken)); err != nil { + return err + } + + if toUpdate.Config, err = json.Marshal(toUpdateConfig); err != nil { + return err + } + + return nil + }); err != nil { + return nil, s.handleError(w, r, a, provider.Config().ID, nil, err) } - return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject))) + if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, a, i, sess, provider.Config().ID); err != nil { + return nil, s.handleError(w, r, a, provider.Config().ID, nil, err) + } + return nil, nil } func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ uuid.UUID) (i *identity.Identity, err error) { diff --git a/selfservice/strategy/oidc/strategy_settings.go b/selfservice/strategy/oidc/strategy_settings.go index 1866cd82aaab..ea7148a36850 100644 --- a/selfservice/strategy/oidc/strategy_settings.go +++ b/selfservice/strategy/oidc/strategy_settings.go @@ -428,6 +428,9 @@ func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdat InitialAccessToken: cat, InitialRefreshToken: crt, InitialIDToken: it, + CurrentAccessToken: cat, + CurrentRefreshToken: crt, + CurrentIDToken: it, }) creds.Config, err = json.Marshal(conf)