diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go
index bbe8daf2a4bea..d8758add68cdf 100644
--- a/integration/kube_integration_test.go
+++ b/integration/kube_integration_test.go
@@ -163,7 +163,6 @@ func TestKube(t *testing.T) {
 	t.Run("TrustedClustersSNI", suite.bind(testKubeTrustedClustersSNI))
 	t.Run("Disconnect", suite.bind(testKubeDisconnect))
 	t.Run("Join", suite.bind(testKubeJoin))
-	t.Run("ConnectionLimit", suite.bind(testKubeConnectionLimit))
 }
 
 // TestKubeExec tests kubernetes Exec command set
@@ -1612,82 +1611,3 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
 	require.Contains(t, participantOutput, []byte("echo hi"))
 	require.Contains(t, out.String(), []byte("echo hi2"))
 }
-
-// testKubeConnectionLimit checks that the `max_kubernetes_connections` role option is enforced.
-func testKubeConnectionLimit(t *testing.T, suite *KubeSuite) {
-	teleport := NewInstance(InstanceConfig{
-		ClusterName: Site,
-		HostID:      HostID,
-		NodeName:    Host,
-		Priv:        suite.priv,
-		Pub:         suite.pub,
-		log:         suite.log,
-	})
-
-	const maxConnections = 10
-	hostUsername := suite.me.Username
-	kubeGroups := []string{testImpersonationGroup}
-	kubeUsers := []string{"alice@example.com"}
-	role, err := types.NewRoleV3("kubemaster", types.RoleSpecV5{
-		Allow: types.RoleConditions{
-			Logins:     []string{hostUsername},
-			KubeGroups: kubeGroups,
-			KubeUsers:  kubeUsers,
-		},
-		Options: types.RoleOptions{
-			MaxKubernetesConnections: maxConnections,
-		},
-	})
-	require.NoError(t, err)
-	teleport.AddUserWithRole(hostUsername, role)
-
-	err = teleport.Start()
-	require.NoError(t, err)
-	t.Cleanup(func() { require.NoError(t, teleport.StopAll()) })
-
-	// set up kube configuration using proxy
-	proxyClient, proxyClientConfig, err := kubeProxyClient(kubeProxyConfig{
-		t:          teleport,
-		username:   hostUsername,
-		kubeUsers:  kubeUsers,
-		kubeGroups: kubeGroups,
-	})
-	require.NoError(t, err)
-
-	ctx := context.Background()
-	// try get request to fetch available pods
-	pod, err := proxyClient.CoreV1().Pods(testNamespace).Get(ctx, testPod, metav1.GetOptions{})
-	require.NoError(t, err)
-
-	openExec := func() error {
-		// interactive command, allocate pty
-		term := NewTerminal(250)
-		out := &bytes.Buffer{}
-
-		return kubeExec(proxyClientConfig, kubeExecArgs{
-			podName:      pod.Name,
-			podNamespace: pod.Namespace,
-			container:    pod.Spec.Containers[0].Name,
-			command:      []string{"/bin/sh", "-c", "sleep 300"},
-			stdout:       out,
-			tty:          true,
-			stdin:        term,
-		})
-	}
-
-	// Create and maintain the maximum amount of open connections
-	for i := 0; i < maxConnections; i++ {
-		go openExec()
-	}
-
-	// Wait for the connections to open and check for any errors
-	require.Eventually(t, func() bool {
-		trackers, err := teleport.Process.GetAuthServer().GetActiveSessionTrackers(ctx)
-		require.NoError(t, err)
-		return len(trackers) == maxConnections
-	}, time.Second*30, time.Second)
-
-	// Open one more connection. It should fail due to the limit.
-	err = openExec()
-	require.Error(t, err)
-}
diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go
index 3fdb5fa7d711a..84b868d0280a6 100644
--- a/lib/kube/proxy/forwarder.go
+++ b/lib/kube/proxy/forwarder.go
@@ -453,7 +453,14 @@ func (f *Forwarder) withAuth(handler handlerWithAuthFunc) httprouter.Handle {
 		if err := f.authorize(req.Context(), authContext); err != nil {
 			return nil, trace.Wrap(err)
 		}
-		if err := f.acquireConnectionLock(req.Context(), authContext); err != nil {
+
+		user := authContext.Identity.GetIdentity().Username
+		roles, err := getRolesByName(f, authContext.Identity.GetIdentity().Groups)
+		if err != nil {
+			return nil, trace.Wrap(err)
+		}
+
+		if err := f.AcquireConnectionLock(req.Context(), user, roles); err != nil {
 			return nil, trace.Wrap(err)
 		}
 		return handler(authContext, w, req, p)
@@ -897,25 +904,23 @@ func wsProxy(wsSource *websocket.Conn, wsTarget *websocket.Conn) error {
 	return trace.Wrap(err)
 }
 
-// acquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent.
+// AcquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent.
 // The semaphore is released when the request is returned/connection is closed.
 // Returns an error if a semaphore could not be acquired.
-func (f *Forwarder) acquireConnectionLock(ctx context.Context, identity *authContext) error {
-	user := identity.Identity.GetIdentity().Username
-	roles, err := getRolesByName(f, identity.Identity.GetIdentity().Groups)
-	if err != nil {
-		return trace.Wrap(err)
+func (f *Forwarder) AcquireConnectionLock(ctx context.Context, user string, roles services.RoleSet) error {
+	maxConnections := roles.MaxKubernetesConnections()
+	if maxConnections == 0 {
+		return nil
 	}
-	maxConnections := services.RoleSet(roles).MaxKubernetesConnections()
-	semLock, err := services.AcquireSemaphoreLock(ctx, services.SemaphoreLockConfig{
+
+	_, err := services.AcquireSemaphoreLock(ctx, services.SemaphoreLockConfig{
 		Service: f.cfg.AuthClient,
 		Expiry:  sessionMaxLifetime,
 		Params: types.AcquireSemaphoreRequest{
 			SemaphoreKind: types.SemaphoreKindKubernetesConnection,
 			SemaphoreName: user,
 			MaxLeases:     maxConnections,
-			Holder:        identity.teleportCluster.name,
+			Holder:        user,
 		},
 	})
 	if err != nil {
@@ -928,7 +933,7 @@ func (f *Forwarder) acquireConnectionLock(ctx context.Context, identity *authCon
 		return trace.Wrap(err)
 	}
 
-	go semLock.KeepAlive(ctx)
+	return nil
 }
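For orientation before the test changes: `AcquireConnectionLock` enforces `max_kubernetes_connections` by taking a lease on a per-user semaphore, where the role set's limit caps the number of leases and zero means unlimited. Below is a minimal sketch of that counting-semaphore idea; the `connLimiter` type and its fields are illustrative and not part of this change, since the real code delegates to the auth server's semaphore service so the limit holds cluster-wide.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// connLimiter is a hypothetical in-memory stand-in for the backend
// semaphore service: one counter per semaphore name (here, the user).
type connLimiter struct {
	mu     sync.Mutex
	counts map[string]int
}

// acquire mirrors the shape of AcquireConnectionLock: a limit of zero
// means unlimited; otherwise acquiring fails once the cap is reached.
func (l *connLimiter) acquire(user string, maxConnections int) error {
	if maxConnections == 0 {
		return nil
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.counts[user] >= maxConnections {
		return errors.New("too many concurrent kubernetes connections for " + user)
	}
	l.counts[user]++
	return nil
}

// release undoes one acquire, as closing the request/connection
// releases the real semaphore lease.
func (l *connLimiter) release(user string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.counts[user] > 0 {
		l.counts[user]--
	}
}

func main() {
	l := &connLimiter{counts: make(map[string]int)}
	for i := 0; i < 6; i++ {
		fmt.Printf("conn %d: %v\n", i+1, l.acquire("alice", 5)) // the 6th acquire errors
	}
}
```

Unlike this sketch, the real lease also expires after `sessionMaxLifetime` unless the background keepalive renews it, which is what the semaphore.go change below wires up.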
diff --git a/lib/kube/proxy/forwarder_test.go b/lib/kube/proxy/forwarder_test.go
index f23637993be87..51a3d9198d168 100644
--- a/lib/kube/proxy/forwarder_test.go
+++ b/lib/kube/proxy/forwarder_test.go
@@ -45,10 +45,12 @@ import (
 	"github.com/gravitational/teleport/api/types"
 	"github.com/gravitational/teleport/lib/auth"
 	"github.com/gravitational/teleport/lib/auth/testauthority"
+	"github.com/gravitational/teleport/lib/backend/memory"
 	"github.com/gravitational/teleport/lib/defaults"
 	"github.com/gravitational/teleport/lib/fixtures"
 	"github.com/gravitational/teleport/lib/reversetunnel"
 	"github.com/gravitational/teleport/lib/services"
+	"github.com/gravitational/teleport/lib/services/local"
 	"github.com/gravitational/teleport/lib/tlsca"
 	"github.com/gravitational/teleport/lib/utils"
 )
@@ -1053,3 +1055,93 @@ func (m *mockWatcher) Events() <-chan types.Event {
 func (m *mockWatcher) Done() <-chan struct{} {
 	return m.ctx.Done()
 }
+
+func newTestForwarder(ctx context.Context, cfg ForwarderConfig) *Forwarder {
+	return &Forwarder{
+		log:            logrus.New(),
+		router:         *httprouter.New(),
+		cfg:            cfg,
+		activeRequests: make(map[string]context.Context),
+		ctx:            ctx,
+	}
+}
+
+type mockSemaphoreClient struct {
+	auth.ClientI
+	sem types.Semaphores
+}
+
+func (m *mockSemaphoreClient) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) {
+	return m.sem.AcquireSemaphore(ctx, params)
+}
+
+func (m *mockSemaphoreClient) CancelSemaphoreLease(ctx context.Context, lease types.SemaphoreLease) error {
+	return m.sem.CancelSemaphoreLease(ctx, lease)
+}
+
+func TestKubernetesConnectionLimit(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	type testCase struct {
+		name        string
+		connections int
+		role        types.Role
+		assert      require.ErrorAssertionFunc
+	}
+
+	unlimitedRole, err := types.NewRole("unlimited", types.RoleSpecV5{})
+	require.NoError(t, err)
+
+	limitedRole, err := types.NewRole("limited", types.RoleSpecV5{
+		Options: types.RoleOptions{
+			MaxKubernetesConnections: 5,
+		},
+	})
+	require.NoError(t, err)
+
+	testCases := []testCase{
+		{
+			name:        "unlimited",
+			connections: 7,
+			role:        unlimitedRole,
+			assert:      require.NoError,
+		},
+		{
+			name:        "limited-success",
+			connections: 5,
+			role:        limitedRole,
+			assert:      require.NoError,
+		},
+		{
+			name:        "limited-fail",
+			connections: 6,
+			role:        limitedRole,
+			assert:      require.Error,
+		},
+	}
+
+	for _, testCase := range testCases {
+		t.Run(testCase.name, func(t *testing.T) {
+			user, err := types.NewUser("bob")
+			require.NoError(t, err)
+			user.SetRoles([]string{testCase.role.GetName()})
+
+			backend, err := memory.New(memory.Config{})
+			require.NoError(t, err)
+
+			sem := local.NewPresenceService(backend)
+			client := &mockSemaphoreClient{sem: sem}
+			forwarder := newTestForwarder(ctx, ForwarderConfig{
+				AuthClient: client,
+			})
+
+			for i := 0; i < testCase.connections; i++ {
+				err = forwarder.AcquireConnectionLock(ctx, user.GetName(), services.NewRoleSet(testCase.role))
+				if i == testCase.connections-1 {
+					testCase.assert(t, err)
+				}
+			}
+		})
+	}
+}
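A side note on `mockSemaphoreClient` above: embedding the `auth.ClientI` interface lets the struct satisfy that large interface while implementing only the two semaphore methods the forwarder actually calls; any other method hits the nil embedded field and panics, which surfaces unexpected calls during the test. A generic sketch of this common Go pattern follows (the `Store` interface here is illustrative, not from the Teleport codebase):

```go
package main

import "fmt"

// Store is an illustrative wide interface, standing in for auth.ClientI.
type Store interface {
	Get(key string) (string, error)
	Put(key, value string) error
	Delete(key string) error
}

// mockStore satisfies Store by embedding it. Only Get is overridden;
// calling Put or Delete panics (nil embedded interface), flagging any
// code path the test did not expect to exercise.
type mockStore struct {
	Store
}

func (m *mockStore) Get(key string) (string, error) {
	return "value-for-" + key, nil
}

func main() {
	var s Store = &mockStore{}
	v, _ := s.Get("alpha")
	fmt.Println(v) // value-for-alpha
	// s.Put("k", "v") // would panic: embedded Store is nil
}
```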
diff --git a/lib/services/semaphore.go b/lib/services/semaphore.go
index eec3f23978898..d94abec27add7 100644
--- a/lib/services/semaphore.go
+++ b/lib/services/semaphore.go
@@ -125,7 +125,7 @@ func (l *SemaphoreLock) Renewed() <-chan struct{} {
 	return l.renewalC
 }
 
-func (l *SemaphoreLock) KeepAlive(ctx context.Context) {
+func (l *SemaphoreLock) keepAlive(ctx context.Context) {
 	var nodrop bool
 	var err error
 	lease := l.lease0
@@ -227,7 +227,7 @@ func AcquireSemaphoreWithRetry(ctx context.Context, req AcquireSemaphoreWithRetr
 }
 
 // AcquireSemaphoreLock attempts to acquire and hold a semaphore lease. If successfully acquired,
-// background keepalive processes are started and an associated lock handle is returned. Cancelling 
+// background keepalive processes are started and an associated lock handle is returned. Cancelling
 // the supplied context releases the semaphore.
 func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*SemaphoreLock, error) {
 	if err := cfg.CheckAndSetDefaults(); err != nil {
@@ -255,6 +255,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph
 		renewalC: make(chan struct{}),
 		cond:     sync.NewCond(&sync.Mutex{}),
 	}
+	go lock.keepAlive(ctx)
 	return lock, nil
 }
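The call-site cleanups below follow from this change: `AcquireSemaphoreLock` now launches `keepAlive` itself, so the explicit `go lock.KeepAlive(ctx)` calls become redundant and a lease is released by cancelling the supplied context. A hedged usage sketch of the new convention (the function, semaphore name, and limit values are placeholders; it assumes `SemaphoreLockConfig.Service` accepts any `types.Semaphores` implementation, as the mock in the test above suggests):

```go
package example

import (
	"context"
	"time"

	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/services"
	"github.com/gravitational/trace"
)

// holdSemaphore sketches the new calling convention. The lease is held
// for the life of ctx; cancelling ctx releases it and stops the
// internal keepalive loop.
func holdSemaphore(ctx context.Context, svc types.Semaphores) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel() // releases the lease when this function returns

	_, err := services.AcquireSemaphoreLock(ctx, services.SemaphoreLockConfig{
		Service: svc,
		Expiry:  time.Minute, // placeholder; forwarder.go uses sessionMaxLifetime
		Params: types.AcquireSemaphoreRequest{
			SemaphoreKind: types.SemaphoreKindKubernetesConnection,
			SemaphoreName: "alice", // per-user semaphore, as in AcquireConnectionLock
			MaxLeases:     5,
			Holder:        "alice",
		},
	})
	if err != nil {
		return trace.Wrap(err) // e.g. the semaphore is already at MaxLeases
	}
	// Before this change, callers had to follow up with
	// `go lock.KeepAlive(ctx)`; the renewal loop now starts automatically.
	return nil
}
```

Here the deferred cancel releases the lease as soon as the function returns; a real caller ties the context to the connection's lifetime instead.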
diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go
index 3391f9fd86f6a..f4f270a66583e 100644
--- a/lib/services/suite/suite.go
+++ b/lib/services/suite/suite.go
@@ -1175,7 +1175,6 @@ func (s *ServicesTestSuite) SemaphoreFlakiness(c *check.C) {
 
 	lock, err := services.AcquireSemaphoreLock(cancelCtx, cfg)
 	c.Assert(err, check.IsNil)
-	go lock.KeepAlive(cancelCtx)
 
 	for i := 0; i < renewals; i++ {
 		select {
@@ -1219,9 +1218,8 @@ func (s *ServicesTestSuite) SemaphoreContention(c *check.C) {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
-			lock, err := services.AcquireSemaphoreLock(cancelCtx, cfg)
+			_, err := services.AcquireSemaphoreLock(cancelCtx, cfg)
 			c.Assert(err, check.IsNil)
-			go lock.KeepAlive(cancelCtx)
 		}()
 	}
 	wg.Wait()
@@ -1259,9 +1257,8 @@ func (s *ServicesTestSuite) SemaphoreConcurrency(c *check.C) {
 	for i := int64(0); i < attempts; i++ {
 		wg.Add(1)
 		go func() {
-			lock, err := services.AcquireSemaphoreLock(cancelCtx, cfg)
+			_, err := services.AcquireSemaphoreLock(cancelCtx, cfg)
 			if err == nil {
-				go lock.KeepAlive(cancelCtx)
 				atomic.AddInt64(&success, 1)
 			} else {
 				atomic.AddInt64(&failure, 1)
@@ -1291,7 +1288,6 @@ func (s *ServicesTestSuite) SemaphoreLock(c *check.C) {
 	defer cancel()
 	lock, err := services.AcquireSemaphoreLock(cancelCtx, cfg)
 	c.Assert(err, check.IsNil)
-	go lock.KeepAlive(cancelCtx)
 
 	// MaxLeases is 1, so second acquire op fails.
 	_, err = services.AcquireSemaphoreLock(cancelCtx, cfg)
@@ -1307,7 +1303,6 @@
 	cfg.TickRate = time.Millisecond * 50
 	lock, err = services.AcquireSemaphoreLock(cancelCtx, cfg)
 	c.Assert(err, check.IsNil)
-	go lock.KeepAlive(cancelCtx)
 
 	timeout := time.After(time.Second)
diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go
index c2d23564c90ef..bfdb1ebad4760 100644
--- a/lib/srv/regular/sshserver.go
+++ b/lib/srv/regular/sshserver.go
@@ -1020,7 +1020,7 @@ func (s *Server) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionCont
 		}
 		return ctx, trace.Wrap(err)
 	}
-	go semLock.KeepAlive(ctx)
+
 	// ensure that losing the lock closes the connection context. Under normal
 	// conditions, cancellation propagates from the connection context to the
 	// lock, but if we lose the lock due to some error (e.g. poor connectivity