Skip to content

Commit

Permalink
[v8] Open a new remote client when the remote site has changed in a w…
Browse files Browse the repository at this point in the history
…eb session (#13968)

* Open a new remote client when the remote site has changed in a web session

* Test coverage for remoteClientCache
  • Loading branch information
espadolini authored Jun 29, 2022
1 parent 3f553eb commit 532a373
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 32 deletions.
94 changes: 62 additions & 32 deletions lib/web/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ type SessionContext struct {
// clt holds a connection to the root auth. Note that requests made using this
// client are made with the identity of the user and are NOT cached.
clt *auth.Client
// remoteClientCache holds the remote clients that have been used in this
// session.
remoteClientCache

// unsafeCachedAuthClient holds a read-only cache to root auth. Note this access
// point cache is authenticated with the identity of the node, not of the
Expand All @@ -77,9 +80,6 @@ type SessionContext struct {
resources *sessionResources
// session refers the web session created for the user.
session types.WebSession

mu sync.Mutex
remoteClt map[string]auth.ClientI
}

// String returns the text representation of this context
Expand Down Expand Up @@ -127,19 +127,6 @@ func (c *SessionContext) validateBearerToken(ctx context.Context, token string)
return nil
}

func (c *SessionContext) addRemoteClient(siteName string, remoteClient auth.ClientI) {
c.mu.Lock()
defer c.mu.Unlock()
c.remoteClt[siteName] = remoteClient
}

func (c *SessionContext) getRemoteClient(siteName string) (auth.ClientI, bool) {
c.mu.Lock()
defer c.mu.Unlock()
remoteClt, ok := c.remoteClt[siteName]
return remoteClt, ok
}

// GetClient returns the client connected to the auth server
func (c *SessionContext) GetClient() (auth.ClientI, error) {
return c.clt, nil
Expand Down Expand Up @@ -167,7 +154,7 @@ func (c *SessionContext) GetUserClient(site reversetunnel.RemoteSite) (auth.Clie
}

// check if we already have a connection to this cluster
remoteClt, ok := c.getRemoteClient(site.GetName())
remoteClt, ok := c.getRemoteClient(site)
if !ok {
rClt, err := c.newRemoteClient(site)
if err != nil {
Expand All @@ -177,7 +164,10 @@ func (c *SessionContext) GetUserClient(site reversetunnel.RemoteSite) (auth.Clie
// we'll save the remote client in our session context so we don't have to
// build a new connection next time. all remote clients will be closed when
// the session context is closed.
c.addRemoteClient(site.GetName(), rClt)
err = c.addRemoteClient(site, rClt)
if err != nil {
c.log.WithError(err).Info("Failed closing stale remote client for site: ", site.GetName())
}

return rClt, nil
}
Expand Down Expand Up @@ -211,7 +201,7 @@ func (c *SessionContext) tryRemoteTLSClient(cluster reversetunnel.RemoteSite) (a
}
_, err = clt.GetDomainName()
if err != nil {
return clt, trace.Wrap(err)
return nil, trace.NewAggregate(err, clt.Close())
}
return clt, nil
}
Expand Down Expand Up @@ -394,18 +384,7 @@ func (c *SessionContext) GetSessionID() string {
// Close cleans up resources associated with this context and removes it
// from the user context
func (c *SessionContext) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
var errors []error
for _, clt := range c.remoteClt {
if err := clt.Close(); err != nil {
errors = append(errors, err)
}
}
if err := c.clt.Close(); err != nil {
errors = append(errors, err)
}
return trace.NewAggregate(errors...)
return trace.NewAggregate(c.remoteClientCache.Close(), c.clt.Close())
}

// getToken returns the bearer token associated with the underlying
Expand Down Expand Up @@ -825,7 +804,6 @@ func (s *sessionCache) newSessionContextFromSession(session types.WebSession) (*
ctx := &SessionContext{
clt: userClient,
unsafeCachedAuthClient: s.accessPoint,
remoteClt: make(map[string]auth.ClientI),
user: session.GetUser(),
session: session,
parent: s,
Expand Down Expand Up @@ -986,3 +964,55 @@ func (h *Handler) waitForWebSession(ctx context.Context, req types.GetWebSession
}
return trace.Wrap(err)
}

// remoteClientCache stores remote clients keyed by site name while also keeping
// track of the actual remote site associated with the client (in case the
// remote site has changed). Safe for concurrent access. Closes all clients and
// wipes the cache on Close.
type remoteClientCache struct {
sync.Mutex
clients map[string]struct {
auth.ClientI
reversetunnel.RemoteSite
}
}

func (c *remoteClientCache) addRemoteClient(site reversetunnel.RemoteSite, remoteClient auth.ClientI) error {
c.Lock()
defer c.Unlock()
if c.clients == nil {
c.clients = make(map[string]struct {
auth.ClientI
reversetunnel.RemoteSite
})
}
var err error
if c.clients[site.GetName()].ClientI != nil {
err = c.clients[site.GetName()].ClientI.Close()
}
c.clients[site.GetName()] = struct {
auth.ClientI
reversetunnel.RemoteSite
}{remoteClient, site}
return err
}

func (c *remoteClientCache) getRemoteClient(site reversetunnel.RemoteSite) (auth.ClientI, bool) {
c.Lock()
defer c.Unlock()
remoteClt, ok := c.clients[site.GetName()]
return remoteClt.ClientI, ok && remoteClt.RemoteSite == site
}

func (c *remoteClientCache) Close() error {
c.Lock()
defer c.Unlock()

errors := make([]error, 0, len(c.clients))
for _, clt := range c.clients {
errors = append(errors, clt.ClientI.Close())
}
c.clients = nil

return trace.NewAggregate(errors...)
}
85 changes: 85 additions & 0 deletions lib/web/sessions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package web

import (
"errors"
"testing"

"github.com/stretchr/testify/require"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/reversetunnel"
)

func TestRemoteClientCache(t *testing.T) {
t.Parallel()

openCount := 0
cache := remoteClientCache{}

sa1 := newMockRemoteSite("a")
sa2 := newMockRemoteSite("a")
sb := newMockRemoteSite("b")

err1 := errors.New("c1")
err2 := errors.New("c2")

require.NoError(t, cache.addRemoteClient(sa1, newMockClientI(&openCount, err1)))
require.Equal(t, 1, openCount)

require.ErrorIs(t, cache.addRemoteClient(sa2, newMockClientI(&openCount, nil)), err1)
require.Equal(t, 1, openCount)

require.NoError(t, cache.addRemoteClient(sb, newMockClientI(&openCount, err2)))
require.Equal(t, 2, openCount)

var aggrErr trace.Aggregate
require.ErrorAs(t, cache.Close(), &aggrErr)
require.ElementsMatch(t, []error{err2}, aggrErr.Errors())

require.Zero(t, openCount)
}

func newMockRemoteSite(name string) reversetunnel.RemoteSite {
return &mockRemoteSite{name: name}
}

type mockRemoteSite struct {
reversetunnel.RemoteSite
name string
}

func (m *mockRemoteSite) GetName() string {
return m.name
}

func newMockClientI(openCount *int, closeErr error) auth.ClientI {
*openCount++
return &mockClientI{openCount: openCount, closeErr: closeErr}
}

type mockClientI struct {
auth.ClientI
openCount *int
closeErr error
}

func (m *mockClientI) Close() error {
*m.openCount--
return m.closeErr
}

0 comments on commit 532a373

Please # to comment.