From c341d2bc1588f2e3ed72b3ffdfe405b8eac7fe35 Mon Sep 17 00:00:00 2001 From: Forrest Marshall Date: Tue, 21 Apr 2020 14:01:40 -0700 Subject: [PATCH] fix agent forwarding for multi-session connections Changes the lifetime of agent forwarding to be scoped to the underlying ssh connection, instead of the specific ssh channel which initially passed the agent forwarding request. --- integration/helpers.go | 6 +- lib/multiplexer/multiplexer_test.go | 4 +- lib/reversetunnel/srv.go | 5 +- lib/srv/ctx.go | 65 ++++++----- lib/srv/forward/sshserver.go | 11 +- lib/srv/regular/sshserver.go | 56 +++++----- lib/srv/regular/sshserver_test.go | 17 ++- lib/sshutils/ctx.go | 161 ++++++++++++++++++++++++++++ lib/sshutils/server.go | 13 ++- lib/sshutils/server_test.go | 11 +- 10 files changed, 264 insertions(+), 85 deletions(-) create mode 100644 lib/sshutils/ctx.go diff --git a/integration/helpers.go b/integration/helpers.go index 5ef0c2ddb85c5..3630aee2416d9 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -1265,11 +1265,11 @@ func (s *discardServer) Stop() { s.sshServer.Close() } -func (s *discardServer) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, newChannel ssh.NewChannel) { +func (s *discardServer) HandleNewChan(ccx *sshutils.ConnectionContext, newChannel ssh.NewChannel) { channel, reqs, err := newChannel.Accept() if err != nil { - sconn.Close() - conn.Close() + ccx.ServerConn.Close() + ccx.NetConn.Close() return } diff --git a/lib/multiplexer/multiplexer_test.go b/lib/multiplexer/multiplexer_test.go index 272371712aaa5..28549c2db29e2 100644 --- a/lib/multiplexer/multiplexer_test.go +++ b/lib/multiplexer/multiplexer_test.go @@ -80,7 +80,7 @@ func (s *MuxSuite) TestMultiplexing(c *check.C) { defer backend1.Close() called := false - sshHandler := sshutils.NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) { + sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) { called = true nch.Reject(ssh.Prohibited, "nothing to see here") }) @@ -381,7 +381,7 @@ func (s *MuxSuite) TestDisableTLS(c *check.C) { defer backend1.Close() called := false - sshHandler := sshutils.NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) { + sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) { called = true nch.Reject(ssh.Prohibited, "nothing to see here") }) diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 997b0af369a16..d8ff22f68f2ca 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -527,11 +527,12 @@ func (s *server) Shutdown(ctx context.Context) error { return s.srv.Shutdown(ctx) } -func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) { +func (s *server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { // Apply read/write timeouts to the server connection. - conn = utils.ObeyIdleTimeout(conn, + conn := utils.ObeyIdleTimeout(ccx.NetConn, s.offlineThreshold, "reverse tunnel server") + sconn := ccx.ServerConn channelType := nch.ChannelType() switch channelType { diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 322fd6a07e2f1..c64e0fe3478f1 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -174,6 +174,8 @@ type ServerContext struct { sync.RWMutex + Parent *sshutils.ConnectionContext + // env is a list of environment variables passed to the session. env map[string]string @@ -186,12 +188,6 @@ type ServerContext struct { // term holds PTY if it was requested by the session. term Terminal - // agent is a client to remote SSH agent. - agent agent.Agent - - // agentCh is SSH channel using SSH agent protocol. - agentChannel ssh.Channel - // session holds the active session (if there's an active one). session *session @@ -291,7 +287,7 @@ type ServerContext struct { // NewServerContext creates a new *ServerContext which is used to pass and // manage resources. -func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext IdentityContext) (*ServerContext, error) { +func NewServerContext(ccx *sshutils.ConnectionContext, srv Server, identityContext IdentityContext) (*ServerContext, error) { clusterConfig, err := srv.GetAccessPoint().GetClusterConfig() if err != nil { return nil, trace.Wrap(err) @@ -300,13 +296,15 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity cancelContext, cancel := context.WithCancel(context.TODO()) ctx := &ServerContext{ + Parent: ccx, id: int(atomic.AddInt32(&ctxID, int32(1))), env: make(map[string]string), srv: srv, - Conn: conn, + Connection: ccx.NetConn, + Conn: ccx.ServerConn, ExecResultCh: make(chan ExecResult, 10), SubsystemResultCh: make(chan SubsystemResult, 10), - ClusterName: conn.Permissions.Extensions[utils.CertTeleportClusterName], + ClusterName: ccx.ServerConn.Permissions.Extensions[utils.CertTeleportClusterName], ClusterConfig: clusterConfig, Identity: identityContext, clientIdleTimeout: identityContext.RoleSet.AdjustClientIdleTimeout(clusterConfig.GetClientIdleTimeout()), @@ -320,8 +318,8 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity } fields := log.Fields{ - "local": conn.LocalAddr(), - "remote": conn.RemoteAddr(), + "local": ctx.Conn.LocalAddr(), + "remote": ctx.Conn.RemoteAddr(), "login": ctx.Identity.Login, "teleportUser": ctx.Identity.TeleportUser, "id": ctx.id, @@ -343,7 +341,7 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity ClientIdleTimeout: ctx.clientIdleTimeout, Clock: ctx.srv.GetClock(), Tracker: ctx, - Conn: conn, + Conn: ctx.Conn, Context: cancelContext, TeleportUser: ctx.Identity.TeleportUser, Login: ctx.Identity.Login, @@ -374,6 +372,9 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity ctx.AddCloser(ctx.contr) ctx.AddCloser(ctx.contw) + // gather environment variables from parent. + ctx.ImportParentEnv() + return ctx, nil } @@ -447,30 +448,22 @@ func (c *ServerContext) AddCloser(closer io.Closer) { c.closers = append(c.closers, closer) } -// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent. +// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent, +// or nil if no agent is available in this context. func (c *ServerContext) GetAgent() agent.Agent { - c.RLock() - defer c.RUnlock() - return c.agent + if c.Parent == nil { + return nil + } + return c.Parent.GetAgent() } -// GetAgentChannel returns the channel over which communication with the agent occurs. +// GetAgentChannel returns the channel over which communication with the agent occurs, +// or nil if no agent is available in this context. func (c *ServerContext) GetAgentChannel() ssh.Channel { - c.RLock() - defer c.RUnlock() - return c.agentChannel -} - -// SetAgent sets the agent and channel over which communication with the agent occurs. -func (c *ServerContext) SetAgent(a agent.Agent, channel ssh.Channel) { - c.Lock() - defer c.Unlock() - if c.agentChannel != nil { - c.Infof("closing previous agent channel") - c.agentChannel.Close() + if c.Parent == nil { + return nil } - c.agentChannel = channel - c.agent = a + return c.Parent.GetAgentChannel() } // GetTerm returns a Terminal. @@ -500,6 +493,12 @@ func (c *ServerContext) GetEnv(key string) (string, bool) { return val, ok } +// ImportParentEnv is used to re-synchronize env vars after +// parent context has been updated. +func (c *ServerContext) ImportParentEnv() { + c.Parent.ExportEnv(c.env) +} + // takeClosers returns all resources that should be closed and sets the properties to null // we do this to avoid calling Close() under lock to avoid potential deadlocks func (c *ServerContext) takeClosers() []io.Closer { @@ -512,10 +511,6 @@ func (c *ServerContext) takeClosers() []io.Closer { closers = append(closers, c.term) c.term = nil } - if c.agentChannel != nil { - closers = append(closers, c.agentChannel) - c.agentChannel = nil - } closers = append(closers, c.closers...) c.closers = nil return closers diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 61c23f5bbff66..35da17df8ad4d 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -88,6 +88,10 @@ type Server struct { // forwarding, subsystems. remoteClient *ssh.Client + // connectionContext is used to construct ServerContext instances + // and supports registration of connection-scoped resource closers. + connectionContext *sshutils.ConnectionContext + // identityContext holds identity information about the user that has // authenticated on sconn (like system login, Teleport username, roles). identityContext srv.IdentityContext @@ -435,6 +439,8 @@ func (s *Server) Serve() { } s.sconn = sconn + s.connectionContext = sshutils.NewConnectionContext(s.serverConn, s.sconn) + // Take connection and extract identity information for the user from it. s.identityContext, err = s.authHandlers.CreateIdentityContext(sconn) if err != nil { @@ -488,6 +494,7 @@ func (s *Server) Close() error { s.serverConn, s.targetConn, s.remoteClient, + s.connectionContext, } var errs []error @@ -646,7 +653,7 @@ func (s *Server) handleChannel(nch ssh.NewChannel) { func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTCPIPReq) { // Create context for this channel. This context will be closed when // forwarding is complete. - ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext) + ctx, err := srv.NewServerContext(s.connectionContext, s, s.identityContext) if err != nil { ctx.Errorf("Unable to create connection context: %v.", err) ch.Stderr().Write([]byte("Unable to create connection context.")) @@ -713,7 +720,7 @@ func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) { // There is no need for the forwarding server to initiate disconnects, // based on teleport business logic, because this logic is already // done on the server's terminating side. - ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext) + ctx, err := srv.NewServerContext(s.connectionContext, s, s.identityContext) if err != nil { ctx.Errorf("Unable to create connection context: %v.", err) ch.Stderr().Write([]byte("Unable to create connection context.")) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 0d9254d7f5166..33baeb04a5234 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -761,15 +761,18 @@ func (s *Server) serveAgent(ctx *srv.ServerContext) error { } // start an agent on a unix socket - agentServer := &teleagent.AgentServer{Agent: ctx.GetAgent()} + agentServer := &teleagent.AgentServer{Agent: ctx.Parent.GetAgent()} err = agentServer.ListenUnixSocket(socketPath, uid, gid, 0600) if err != nil { return trace.Wrap(err) } - ctx.SetEnv(teleport.SSHAuthSock, socketPath) - ctx.SetEnv(teleport.SSHAgentPID, fmt.Sprintf("%v", pid)) - ctx.AddCloser(agentServer) - ctx.AddCloser(dirCloser) + ctx.Parent.SetEnv(teleport.SSHAuthSock, socketPath) + ctx.Parent.SetEnv(teleport.SSHAgentPID, fmt.Sprintf("%v", pid)) + ctx.Parent.AddCloser(agentServer) + ctx.Parent.AddCloser(dirCloser) + // ensure that SSHAuthSock and SSHAgentPID are imported into + // the current child context. + ctx.ImportParentEnv() ctx.Debugf("Opened agent channel for Teleport user %v and socket %v.", ctx.Identity.TeleportUser, socketPath) go agentServer.Serve() @@ -816,8 +819,8 @@ func (s *Server) HandleRequest(r *ssh.Request) { } // HandleNewChan is called when new channel is opened -func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) { - identityContext, err := s.authHandlers.CreateIdentityContext(sconn) +func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { + identityContext, err := s.authHandlers.CreateIdentityContext(ccx.ServerConn) if err != nil { nch.Reject(ssh.Prohibited, fmt.Sprintf("Unable to create identity from connection: %v", err)) return @@ -841,7 +844,7 @@ func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.Ne nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleProxyJump(wconn, sconn, identityContext, ch, *req) + go s.handleProxyJump(ccx, identityContext, ch, *req) return // Channels of type "session" handle requests that are involved in running // commands on a server. In the case of proxy mode subsystem and agent @@ -853,7 +856,7 @@ func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.Ne nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleSessionRequests(wconn, sconn, identityContext, ch, requests) + go s.handleSessionRequests(ccx, identityContext, ch, requests) return default: nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)) @@ -871,7 +874,7 @@ func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.Ne nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleSessionRequests(wconn, sconn, identityContext, ch, requests) + go s.handleSessionRequests(ccx, identityContext, ch, requests) // Channels of type "direct-tcpip" handles request for port forwarding. case teleport.ChanDirectTCPIP: req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData()) @@ -886,23 +889,22 @@ func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.Ne nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleDirectTCPIPRequest(wconn, sconn, identityContext, ch, req) + go s.handleDirectTCPIPRequest(ccx, identityContext, ch, req) default: nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)) } } // handleDirectTCPIPRequest handles port forwarding requests. -func (s *Server) handleDirectTCPIPRequest(wconn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, channel ssh.Channel, req *sshutils.DirectTCPIPReq) { +func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, channel ssh.Channel, req *sshutils.DirectTCPIPReq) { // Create context for this channel. This context will be closed when // forwarding is complete. - ctx, err := srv.NewServerContext(s, sconn, identityContext) + ctx, err := srv.NewServerContext(ccx, s, identityContext) if err != nil { log.Errorf("Unable to create connection context: %v.", err) channel.Stderr().Write([]byte("Unable to create connection context.")) return } - ctx.Connection = wconn ctx.IsTestStub = s.isTestStub ctx.AddCloser(channel) ctx.ChannelType = teleport.ChanDirectTCPIP @@ -988,24 +990,23 @@ func (s *Server) handleDirectTCPIPRequest(wconn net.Conn, sconn *ssh.ServerConn, events.PortForwardSuccess: true, events.EventLogin: ctx.Identity.Login, events.EventUser: ctx.Identity.TeleportUser, - events.LocalAddr: sconn.LocalAddr().String(), - events.RemoteAddr: sconn.RemoteAddr().String(), + events.LocalAddr: ctx.Conn.LocalAddr().String(), + events.RemoteAddr: ctx.Conn.RemoteAddr().String(), }) } // handleSessionRequests handles out of band session requests once the session // channel has been created this function's loop handles all the "exec", // "subsystem" and "shell" requests. -func (s *Server) handleSessionRequests(conn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) { +func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) { // Create context for this channel. This context will be closed when the // session request is complete. - ctx, err := srv.NewServerContext(s, sconn, identityContext) + ctx, err := srv.NewServerContext(ccx, s, identityContext) if err != nil { log.Errorf("Unable to create connection context: %v.", err) ch.Stderr().Write([]byte("Unable to create connection context.")) return } - ctx.Connection = conn ctx.IsTestStub = s.isTestStub ctx.AddCloser(ch) ctx.ChannelType = teleport.ChanSession @@ -1028,7 +1029,7 @@ func (s *Server) handleSessionRequests(conn net.Conn, sconn *ssh.ServerConn, ide // closeContext which signals the server to shutdown. go srv.StartKeepAliveLoop(srv.KeepAliveParams{ Conns: []srv.RequestSender{ - sconn, + ctx.Conn, }, Interval: clusterConfig.GetKeepAliveInterval(), MaxCount: clusterConfig.GetKeepAliveCountMax(), @@ -1062,7 +1063,7 @@ func (s *Server) handleSessionRequests(conn net.Conn, sconn *ssh.ServerConn, ide case req := <-in: if req == nil { // this will happen when the client closes/drops the connection - ctx.Debugf("Client %v disconnected.", sconn.RemoteAddr()) + ctx.Debugf("Client %v disconnected.", ctx.Conn.RemoteAddr()) return } if err := s.dispatch(ch, req, ctx); err != nil { @@ -1171,7 +1172,7 @@ func (s *Server) handleAgentForwardNode(req *ssh.Request, ctx *srv.ServerContext } // save the agent in the context so it can be used later - ctx.SetAgent(agent.NewClient(authChannel), authChannel) + ctx.Parent.SetAgent(agent.NewClient(authChannel), authChannel) // serve an agent on a unix socket on this node err = s.serveAgent(ctx) @@ -1209,7 +1210,7 @@ func (s *Server) handleAgentForwardProxy(req *ssh.Request, ctx *srv.ServerContex // Save the agent so it can be used when making a proxy subsystem request // later. It will also be used when building a remote connection to the // target node. - ctx.SetAgent(agent.NewClient(authChannel), authChannel) + ctx.Parent.SetAgent(agent.NewClient(authChannel), authChannel) return nil } @@ -1304,16 +1305,15 @@ func (s *Server) handleVersionRequest(req *ssh.Request) { } // handleProxyJump handles ProxyJump request that is executed via direct tcp-ip dial on the proxy -func (s *Server) handleProxyJump(conn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req sshutils.DirectTCPIPReq) { +func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, ch ssh.Channel, req sshutils.DirectTCPIPReq) { // Create context for this channel. This context will be closed when the // session request is complete. - ctx, err := srv.NewServerContext(s, sconn, identityContext) + ctx, err := srv.NewServerContext(ccx, s, identityContext) if err != nil { log.Errorf("Unable to create connection context: %v.", err) ch.Stderr().Write([]byte("Unable to create connection context.")) return } - ctx.Connection = conn ctx.IsTestStub = s.isTestStub ctx.AddCloser(ch) defer ctx.Close() @@ -1364,7 +1364,7 @@ func (s *Server) handleProxyJump(conn net.Conn, sconn *ssh.ServerConn, identityC // closeContext which signals the server to shutdown. go srv.StartKeepAliveLoop(srv.KeepAliveParams{ Conns: []srv.RequestSender{ - sconn, + ctx.Conn, }, Interval: clusterConfig.GetKeepAliveInterval(), MaxCount: clusterConfig.GetKeepAliveCountMax(), @@ -1391,7 +1391,7 @@ func (s *Server) handleProxyJump(conn net.Conn, sconn *ssh.ServerConn, identityC return } - if err := subsys.Start(sconn, ch, &ssh.Request{}, ctx); err != nil { + if err := subsys.Start(ctx.Conn, ch, &ssh.Request{}, ctx); err != nil { log.Errorf("Unable to start proxy subsystem: %v.", err) ch.Stderr().Write([]byte("Unable to start proxy subsystem.")) return diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 221037d3c428c..49c3388b97184 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -352,8 +352,21 @@ func (s *SrvSuite) TestAgentForward(c *C) { err = client.Close() c.Assert(err, IsNil) - // make sure the socket is gone after we closed the session - se.Close() + // make sure the socket persists after the session is closed. + // (agents are started from specific sessions, but apply to all + // sessions on the connection). + err = se.Close() + c.Assert(err, IsNil) + // Pause to allow closure to propagate. + time.Sleep(150 * time.Millisecond) + _, err = net.Dial("unix", socketPath) + c.Assert(err, IsNil) + + // make sure the socket is gone after we closed the connection. + err = s.clt.Close() + c.Assert(err, IsNil) + // clt must be nullified to prevent double-close during test cleanup + s.clt = nil for i := 0; i < 4; i++ { _, err = net.Dial("unix", socketPath) if err != nil { diff --git a/lib/sshutils/ctx.go b/lib/sshutils/ctx.go new file mode 100644 index 0000000000000..1578aac0d8989 --- /dev/null +++ b/lib/sshutils/ctx.go @@ -0,0 +1,161 @@ +/* +Copyright 2020 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sshutils + +import ( + "io" + "net" + "sync" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "github.com/gravitational/trace" +) + +// ConnectionContext manages connection-level state. +type ConnectionContext struct { + // NetConn is the base connection object. + NetConn net.Conn + + // ServerConn is authenticated ssh connection. + ServerConn *ssh.ServerConn + + // mu protects the rest of the state + mu sync.RWMutex + + // env holds environment variables which should be + // set for all channels. + env map[string]string + + // agent is a client to remote SSH agent. + agent agent.Agent + + // agentCh is SSH channel using SSH agent protocol. + agentChannel ssh.Channel + // closers is a list of io.Closer that will be called when session closes + // this is handy as sometimes client closes session, in this case resources + // will be properly closed and deallocated, otherwise they could be kept hanging. + closers []io.Closer +} + +// NewConnectionContext creates a new ConnectionContext instance. +func NewConnectionContext(nconn net.Conn, sconn *ssh.ServerConn) *ConnectionContext { + return &ConnectionContext{ + NetConn: nconn, + ServerConn: sconn, + env: make(map[string]string), + } +} + +// SetEnv sets a environment variable within this context. +func (c *ConnectionContext) SetEnv(key, val string) { + c.mu.Lock() + defer c.mu.Unlock() + c.env[key] = val +} + +// GetEnv returns a environment variable within this context. +func (c *ConnectionContext) GetEnv(key string) (string, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + val, ok := c.env[key] + return val, ok +} + +// ExportEnv writes all env vars to supplied map (used to configure +// env of child contexts). +func (c *ConnectionContext) ExportEnv(m map[string]string) { + c.mu.RLock() + defer c.mu.RUnlock() + for key, val := range c.env { + m[key] = val + } +} + +// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent, +// or nil if no agent is available in this context. +func (c *ConnectionContext) GetAgent() agent.Agent { + c.mu.RLock() + defer c.mu.RUnlock() + return c.agent +} + +// GetAgentChannel returns the channel over which communication with the agent occurs, +// or nil if no agent is available in this context. +func (c *ConnectionContext) GetAgentChannel() ssh.Channel { + c.mu.RLock() + defer c.mu.RUnlock() + return c.agentChannel +} + +// SetAgent sets the agent and channel over which communication with the agent occurs. +func (c *ConnectionContext) SetAgent(a agent.Agent, channel ssh.Channel) { + c.mu.Lock() + defer c.mu.Unlock() + if c.agentChannel != nil { + c.agentChannel.Close() + } + c.agentChannel = channel + c.agent = a +} + +// AddCloser adds any closer in ctx that will be called +// when the underlying connection is closed. +func (c *ConnectionContext) AddCloser(closer io.Closer) { + c.mu.Lock() + defer c.mu.Unlock() + c.closers = append(c.closers, closer) +} + +// takeClosers returns all resources that should be closed and sets the properties to null +// we do this to avoid calling Close() under lock to avoid potential deadlocks +func (c *ConnectionContext) takeClosers() []io.Closer { + // this is done to avoid any operation holding the lock for too long + c.mu.Lock() + defer c.mu.Unlock() + + closers := c.closers + c.closers = nil + if c.agentChannel != nil { + closers = append(closers, c.agentChannel) + c.agentChannel = nil + } + return closers +} + +// Close closes associated resources (e.g. agent channel). +func (c *ConnectionContext) Close() error { + var errs []error + + closers := c.takeClosers() + + for _, cl := range closers { + if cl == nil { + continue + } + + err := cl.Close() + if err == nil { + continue + } + + errs = append(errs, err) + } + + return trace.NewAggregate(errs...) +} diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go index c144be213dbbf..d1eed80c9d020 100644 --- a/lib/sshutils/server.go +++ b/lib/sshutils/server.go @@ -432,6 +432,9 @@ func (s *Server) HandleConnection(conn net.Conn) { defer keepAliveTick.Stop() keepAlivePayload := [8]byte{0} + ccx := NewConnectionContext(wconn, sconn) + defer ccx.Close() + for { select { // handle out of band ssh requests @@ -450,7 +453,7 @@ func (s *Server) HandleConnection(conn net.Conn) { connClosed() return } - go s.newChanHandler.HandleNewChan(wconn, sconn, nch) + go s.newChanHandler.HandleNewChan(ccx, nch) // send keepalive pings to the clients case <-keepAliveTick.C: const wantReply = true @@ -470,13 +473,13 @@ func (f RequestHandlerFunc) HandleRequest(r *ssh.Request) { } type NewChanHandler interface { - HandleNewChan(net.Conn, *ssh.ServerConn, ssh.NewChannel) + HandleNewChan(*ConnectionContext, ssh.NewChannel) } -type NewChanHandlerFunc func(net.Conn, *ssh.ServerConn, ssh.NewChannel) +type NewChanHandlerFunc func(*ConnectionContext, ssh.NewChannel) -func (f NewChanHandlerFunc) HandleNewChan(conn net.Conn, sshConn *ssh.ServerConn, ch ssh.NewChannel) { - f(conn, sshConn, ch) +func (f NewChanHandlerFunc) HandleNewChan(ccx *ConnectionContext, ch ssh.NewChannel) { + f(ccx, ch) } type AuthMethods struct { diff --git a/lib/sshutils/server_test.go b/lib/sshutils/server_test.go index 621cff948e628..e7d8cfc81ff1a 100644 --- a/lib/sshutils/server_test.go +++ b/lib/sshutils/server_test.go @@ -19,7 +19,6 @@ package sshutils import ( "context" "fmt" - "net" "testing" "time" @@ -52,7 +51,7 @@ func (s *ServerSuite) SetUpSuite(c *check.C) { func (s *ServerSuite) TestStartStop(c *check.C) { called := false - fn := NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) { + fn := NewChanHandlerFunc(func(_ *ConnectionContext, nch ssh.NewChannel) { called = true nch.Reject(ssh.Prohibited, "nothing to see here") }) @@ -86,13 +85,13 @@ func (s *ServerSuite) TestStartStop(c *check.C) { // TestShutdown tests graceul shutdown feature func (s *ServerSuite) TestShutdown(c *check.C) { closeContext, cancel := context.WithCancel(context.TODO()) - fn := NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) { + fn := NewChanHandlerFunc(func(ccx *ConnectionContext, nch ssh.NewChannel) { ch, _, err := nch.Accept() c.Assert(err, check.IsNil) defer ch.Close() select { case <-closeContext.Done(): - conn.Close() + ccx.ServerConn.Close() } }) @@ -136,7 +135,7 @@ func (s *ServerSuite) TestShutdown(c *check.C) { } func (s *ServerSuite) TestConfigureCiphers(c *check.C) { - fn := NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) { + fn := NewChanHandlerFunc(func(_ *ConnectionContext, nch ssh.NewChannel) { nch.Reject(ssh.Prohibited, "nothing to see here") }) @@ -182,7 +181,7 @@ func (s *ServerSuite) TestHostSignerFIPS(c *check.C) { _, ellipticSigner, err := utils.CreateEllipticCertificate("foo", ssh.HostCert) c.Assert(err, check.IsNil) - newChanHandler := NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) { + newChanHandler := NewChanHandlerFunc(func(_ *ConnectionContext, nch ssh.NewChannel) { nch.Reject(ssh.Prohibited, "nothing to see here") })