From 88cd16a27bf46a14d298b6a7d48b81cb1908f1c9 Mon Sep 17 00:00:00 2001
From: Forrest Marshall
Date: Fri, 5 Jun 2020 17:31:11 -0700
Subject: [PATCH] Make agent channel setup lazy.

Changes agent channel setup behavior to be consistent with openssh by
having servers lazily request agent channels when they are needed,
rather than immediately starting a single connection-wide channel as
soon as forwarding is requested.

Fixes an issue introduced in #3613 which caused openssh clients to hang
on exit due to a persistent agent channel.
---
 integration/helpers.go                  |   9 +-
 lib/multiplexer/multiplexer_test.go     |   5 +-
 lib/reversetunnel/api.go                |   7 +-
 lib/reversetunnel/localsite.go          |  16 +-
 lib/reversetunnel/remotesite.go         |  16 +-
 lib/reversetunnel/srv.go                |   2 +-
 lib/reversetunnel/track/tracker_test.go |  23 ++-
 lib/srv/authhandlers.go                 |   4 +-
 lib/srv/ctx.go                          | 188 +++++++++++-------------
 lib/srv/exec.go                         |   8 +-
 lib/srv/exec_test.go                    |  12 +-
 lib/srv/forward/sshserver.go            | 136 ++++++++---------
 lib/srv/regular/proxy.go                | 119 ++++++++-------
 lib/srv/regular/proxy_test.go           |  91 ++++++------
 lib/srv/regular/sshserver.go            | 186 +++++++++++------------
 lib/srv/sess.go                         |  16 +-
 lib/sshutils/ctx.go                     | 119 +++++++++------
 lib/sshutils/server.go                  |  18 ++-
 lib/sshutils/server_test.go             |   8 +-
 lib/teleagent/agent.go                  |  47 +++++-
 20 files changed, 553 insertions(+), 477 deletions(-)

diff --git a/integration/helpers.go b/integration/helpers.go
index 17be6ee433a1d..ab9c13f20c8a0 100644
--- a/integration/helpers.go
+++ b/integration/helpers.go
@@ -1262,7 +1262,7 @@ func (s *discardServer) Stop() {
 	s.sshServer.Close()
 }
 
-func (s *discardServer) HandleNewChan(ccx *sshutils.ConnectionContext, newChannel ssh.NewChannel) {
+func (s *discardServer) HandleNewChan(_ context.Context, ccx *sshutils.ConnectionContext, newChannel ssh.NewChannel) {
 	channel, reqs, err := newChannel.Accept()
 	if err != nil {
 		ccx.ServerConn.Close()
@@ -1400,10 +1400,13 @@ func createAgent(me *user.User, privateKeyByte []byte, certificateBytes []byte)
 	}
 
 	// create a (unstarted) agent and add the key to it
-	teleAgent := teleagent.NewServer()
-	if err := teleAgent.Add(agentKey); err != nil {
+	keyring := agent.NewKeyring()
+	if err := keyring.Add(agentKey); err != nil {
 		return nil, "", "", trace.Wrap(err)
 	}
+	teleAgent := teleagent.NewServer(func() (teleagent.Agent, error) {
+		return teleagent.NopCloser(keyring), nil
+	})
 
 	// start the SSH agent
 	err = teleAgent.ListenUnixSocket(sockPath, uid, gid, 0600)
diff --git a/lib/multiplexer/multiplexer_test.go b/lib/multiplexer/multiplexer_test.go
index 08db0d97b36c5..7ae0258d0853c 100644
--- a/lib/multiplexer/multiplexer_test.go
+++ b/lib/multiplexer/multiplexer_test.go
@@ -17,6 +17,7 @@ limitations under the License.
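// Sketch (illustrative, not part of the patch): the lazy-agent contract
// introduced above. teleagent.NewServer now takes a getter that is invoked
// once per connection accepted on the unix socket, so each client gets its
// own agent instance instead of sharing one long-lived channel. The keyring
// and socket variable names are illustrative; the teleagent calls mirror
// this patch, and error handling is abbreviated.
keyring := agent.NewKeyring() // golang.org/x/crypto/ssh/agent
agentServer := teleagent.NewServer(func() (teleagent.Agent, error) {
	// NopCloser adapts a plain agent.Agent to teleagent.Agent,
	// which additionally requires Close.
	return teleagent.NopCloser(keyring), nil
})
if err := agentServer.ListenUnixSocket(sockPath, uid, gid, 0600); err != nil {
	return trace.Wrap(err)
}
go agentServer.Serve()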
package multiplexer import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -80,7 +81,7 @@ func (s *MuxSuite) TestMultiplexing(c *check.C) { defer backend1.Close() called := false - sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) { + sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, _ *sshutils.ConnectionContext, nch ssh.NewChannel) { called = true err := nch.Reject(ssh.Prohibited, "nothing to see here") c.Assert(err, check.IsNil) @@ -380,7 +381,7 @@ func (s *MuxSuite) TestDisableTLS(c *check.C) { defer backend1.Close() called := false - sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) { + sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, _ *sshutils.ConnectionContext, nch ssh.NewChannel) { called = true err := nch.Reject(ssh.Prohibited, "nothing to see here") c.Assert(err, check.IsNil) diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go index 644ee49138d60..61e22b100c3eb 100644 --- a/lib/reversetunnel/api.go +++ b/lib/reversetunnel/api.go @@ -22,9 +22,8 @@ import ( "net" "time" - "golang.org/x/crypto/ssh/agent" - "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/teleagent" ) // DialParams is a list of parameters used to Dial to a node within a cluster. @@ -35,9 +34,9 @@ type DialParams struct { // To is the destination address. To net.Addr - // UserAgent is SSH agent used to connect to the remote host. Used by the + // GetUserAgent gets an SSH agent for use in connecting to the remote host. Used by the // forwarding proxy. - UserAgent agent.Agent + GetUserAgent teleagent.Getter // Address is used by the forwarding proxy to generate a host certificate for // the target node. This is needed because while dialing occurs via IP diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index a03abcef1e91c..2733a78327ad9 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -170,9 +170,6 @@ func (s *localSite) Dial(params DialParams) (net.Conn, error) { return nil, trace.Wrap(err) } if clusterConfig.GetSessionRecording() == services.RecordAtProxy { - if params.UserAgent == nil { - return nil, trace.BadParameter("user agent missing") - } return s.dialWithAgent(params) } @@ -195,18 +192,29 @@ func (s *localSite) DialTCP(params DialParams) (net.Conn, error) { func (s *localSite) IsClosed() bool { return false } func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) { + if params.GetUserAgent == nil { + return nil, trace.BadParameter("user agent getter missing") + } s.log.Debugf("Dialing with an agent from %v to %v.", params.From, params.To) + // request user agent connection + userAgent, err := params.GetUserAgent() + if err != nil { + return nil, trace.Wrap(err) + } + // If server ID matches a node that has self registered itself over the tunnel, // return a connection to that node. Otherwise net.Dial to the target host. targetConn, useTunnel, err := s.getConn(params) if err != nil { + userAgent.Close() return nil, trace.Wrap(err) } // Get a host certificate for the forwarding node from the cache. hostCertificate, err := s.certificateCache.GetHostCertificate(params.Address, params.Principals) if err != nil { + userAgent.Close() return nil, trace.Wrap(err) } @@ -215,7 +223,7 @@ func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) { // once conn is closed. 
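// Sketch of the types this hunk implies (the lib/teleagent hunk is not shown
// in this excerpt, so these definitions are inferred rather than
// authoritative). DialParams.GetUserAgent is a teleagent.Getter; a
// teleagent.Agent is an agent.Agent that can also be closed so the backing
// channel is torn down as soon as the dial finishes or fails.
type Agent interface {
	agent.Agent // golang.org/x/crypto/ssh/agent
	io.Closer
}

// Getter lazily opens an agent connection; dialWithAgent above calls it
// only when recording mode actually needs a user agent.
type Getter func() (Agent, error)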
serverConfig := forward.ServerConfig{ AuthClient: s.client, - UserAgent: params.UserAgent, + UserAgent: userAgent, TargetConn: targetConn, SrcAddr: params.From, DstAddr: params.To, diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index e3f3e85295ff0..ca9bf024187ff 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -509,9 +509,6 @@ func (s *remoteSite) Dial(params DialParams) (net.Conn, error) { // If the proxy is in recording mode use the agent to dial and build a // in-memory forwarding server. if clusterConfig.GetSessionRecording() == services.RecordAtProxy { - if params.UserAgent == nil { - return nil, trace.BadParameter("user agent missing") - } return s.dialWithAgent(params) } return s.DialTCP(params) @@ -532,11 +529,21 @@ func (s *remoteSite) DialTCP(params DialParams) (net.Conn, error) { } func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) { + if params.GetUserAgent == nil { + return nil, trace.BadParameter("user agent getter missing") + } s.Debugf("Dialing with an agent from %v to %v.", params.From, params.To) + // request user agent connection + userAgent, err := params.GetUserAgent() + if err != nil { + return nil, trace.Wrap(err) + } + // Get a host certificate for the forwarding node from the cache. hostCertificate, err := s.certificateCache.GetHostCertificate(params.Address, params.Principals) if err != nil { + userAgent.Close() return nil, trace.Wrap(err) } @@ -545,6 +552,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) { ServerID: params.ServerID, }) if err != nil { + userAgent.Close() return nil, trace.Wrap(err) } @@ -556,7 +564,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) { // session gets recorded in the local cluster instead of the remote cluster. serverConfig := forward.ServerConfig{ AuthClient: s.localClient, - UserAgent: params.UserAgent, + UserAgent: userAgent, TargetConn: targetConn, SrcAddr: params.From, DstAddr: params.To, diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 2c504c08292cc..35a8c4cb55e9e 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -527,7 +527,7 @@ func (s *server) Shutdown(ctx context.Context) error { return s.srv.Shutdown(ctx) } -func (s *server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { +func (s *server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { // Apply read/write timeouts to the server connection. 
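// Sketch: channel handlers now receive the connection-scoped context.Context
// as their first argument, so long-running handlers can abort when the
// parent connection closes. Modeled on the test handlers elsewhere in this
// patch; the handler body is illustrative.
var onNewChannel = sshutils.NewChanHandlerFunc(func(ctx context.Context, ccx *sshutils.ConnectionContext, nch ssh.NewChannel) {
	select {
	case <-ctx.Done(): // parent connection is already closing
		nch.Reject(ssh.ConnectionFailed, "connection closed")
	default:
		nch.Reject(ssh.Prohibited, "nothing to see here")
	}
})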
conn := utils.ObeyIdleTimeout(ccx.NetConn, s.offlineThreshold, diff --git a/lib/reversetunnel/track/tracker_test.go b/lib/reversetunnel/track/tracker_test.go index 93eb600bf0044..2b548692ca1a4 100644 --- a/lib/reversetunnel/track/tracker_test.go +++ b/lib/reversetunnel/track/tracker_test.go @@ -78,11 +78,15 @@ func (s *simpleTestProxies) GetRandProxy() (p testProxy, ok bool) { } func (s *simpleTestProxies) Discover(tracker *Tracker, lease Lease) (ok bool) { - defer lease.Release() proxy, ok := s.GetRandProxy() if !ok { panic("discovery called with no available proxies") } + return s.ProxyLoop(tracker, lease, proxy) +} + +func (s *simpleTestProxies) ProxyLoop(tracker *Tracker, lease Lease, proxy testProxy) (ok bool) { + defer lease.Release() timeout := time.After(proxy.life) ok = tracker.WithProxy(func() { ticker := time.NewTicker(jitter(time.Millisecond * 100)) @@ -165,7 +169,7 @@ Discover: break Discover } case <-timeoutC: - panic("timeout") + c.Fatal("timeout") } } } @@ -193,7 +197,14 @@ Loop0: select { case lease := <-tracker.Acquire(): c.Assert(lease.Key().(Key), check.DeepEquals, key) - go proxies.Discover(tracker, lease) + // get our "discovered" proxy in the foreground + // to prevent race with the call to RemoveRandProxies + // that comes after this loop. + proxy, ok := proxies.GetRandProxy() + if !ok { + c.Fatal("failed to get test proxy") + } + go proxies.ProxyLoop(tracker, lease, proxy) case <-ticker.C: counts := tracker.wp.Get(key) c.Logf("Counts0: %+v", counts) @@ -201,7 +212,7 @@ Loop0: break Loop0 } case <-timeoutC: - panic("timeout") + c.Fatal("timeout") } } proxies.RemoveRandProxies(proxyCount) @@ -215,7 +226,7 @@ Loop1: break Loop1 } case <-timeoutC: - panic("timeout") + c.Fatal("timeout") } } proxies.AddRandProxies(proxyCount, minConnB, maxConnB) @@ -231,7 +242,7 @@ Loop2: break Loop2 } case <-timeoutC: - panic("timeout") + c.Fatal("timeout") } } } diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 63a4f1006a312..51bb9698691ed 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -115,8 +115,8 @@ func (h *AuthHandlers) CheckPortForward(addr string, ctx *ServerContext) error { events.PortForwardErr: systemErrorMessage, events.EventLogin: ctx.Identity.Login, events.EventUser: ctx.Identity.TeleportUser, - events.LocalAddr: ctx.Conn.LocalAddr().String(), - events.RemoteAddr: ctx.Conn.RemoteAddr().String(), + events.LocalAddr: ctx.ServerConn.LocalAddr().String(), + events.RemoteAddr: ctx.ServerConn.RemoteAddr().String(), }); err != nil { h.Warnf("Failed to emit port forward deny audit event: %v", err) } diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index b465470f90dfe..4815021182ddd 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -27,7 +27,6 @@ import ( "time" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/auth" @@ -170,12 +169,13 @@ func (c IdentityContext) GetCertificate() (*ssh.Certificate, error) { // used to access resources on the underlying server. SessionContext can also // be used to attach resources that should be closed once the session closes. type ServerContext struct { + // ConnectionContext is the parent context which manages connection-level + // resources. + *sshutils.ConnectionContext *log.Entry sync.RWMutex - Parent *sshutils.ConnectionContext - // env is a list of environment variables passed to the session. 
env map[string]string @@ -196,12 +196,6 @@ type ServerContext struct { // will be properly closed and deallocated, otherwise they could be kept hanging. closers []io.Closer - // Conn is the underlying *ssh.ServerConn. - Conn *ssh.ServerConn - - // Connection is the underlying net.Conn for the connection. - Connection net.Conn - // Identity holds the identity of the user that is currently logged in on // the Conn. Identity IdentityContext @@ -246,9 +240,6 @@ type ServerContext struct { // on client inactivity, set to 0 if not setup clientIdleTimeout time.Duration - // cancelContext signals closure to all outstanding operations - cancelContext context.Context - // cancel is called whenever server context is closed cancel context.CancelFunc @@ -286,96 +277,102 @@ type ServerContext struct { } // NewServerContext creates a new *ServerContext which is used to pass and -// manage resources. -func NewServerContext(ccx *sshutils.ConnectionContext, srv Server, identityContext IdentityContext) (*ServerContext, error) { +// manage resources, and an associated context.Context which is canceled when +// the ServerContext is closed. The ctx parameter should be a child of the ctx +// associated with the scope of the parent ConnectionContext to ensure that +// cancellation of the ConnectionContext propagates to the ServerContext. +func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext) (context.Context, *ServerContext, error) { clusterConfig, err := srv.GetAccessPoint().GetClusterConfig() if err != nil { - return nil, trace.Wrap(err) + return nil, nil, trace.Wrap(err) } - cancelContext, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(ctx) - ctx := &ServerContext{ - Parent: ccx, + child := &ServerContext{ + ConnectionContext: parent, id: int(atomic.AddInt32(&ctxID, int32(1))), env: make(map[string]string), srv: srv, - Connection: ccx.NetConn, - Conn: ccx.ServerConn, ExecResultCh: make(chan ExecResult, 10), SubsystemResultCh: make(chan SubsystemResult, 10), - ClusterName: ccx.ServerConn.Permissions.Extensions[utils.CertTeleportClusterName], + ClusterName: parent.ServerConn.Permissions.Extensions[utils.CertTeleportClusterName], ClusterConfig: clusterConfig, Identity: identityContext, clientIdleTimeout: identityContext.RoleSet.AdjustClientIdleTimeout(clusterConfig.GetClientIdleTimeout()), - cancelContext: cancelContext, cancel: cancel, } disconnectExpiredCert := identityContext.RoleSet.AdjustDisconnectExpiredCert(clusterConfig.GetDisconnectExpiredCert()) if !identityContext.CertValidBefore.IsZero() && disconnectExpiredCert { - ctx.disconnectExpiredCert = identityContext.CertValidBefore + child.disconnectExpiredCert = identityContext.CertValidBefore } fields := log.Fields{ - "local": ctx.Conn.LocalAddr(), - "remote": ctx.Conn.RemoteAddr(), - "login": ctx.Identity.Login, - "teleportUser": ctx.Identity.TeleportUser, - "id": ctx.id, + "local": child.ServerConn.LocalAddr(), + "remote": child.ServerConn.RemoteAddr(), + "login": child.Identity.Login, + "teleportUser": child.Identity.TeleportUser, + "id": child.id, } - if !ctx.disconnectExpiredCert.IsZero() { - fields["cert"] = ctx.disconnectExpiredCert + if !child.disconnectExpiredCert.IsZero() { + fields["cert"] = child.disconnectExpiredCert } - if ctx.clientIdleTimeout != 0 { - fields["idle"] = ctx.clientIdleTimeout + if child.clientIdleTimeout != 0 { + fields["idle"] = child.clientIdleTimeout } - ctx.Entry = log.WithFields(log.Fields{ + child.Entry = 
log.WithFields(log.Fields{ trace.Component: srv.Component(), trace.ComponentFields: fields, }) - if !ctx.disconnectExpiredCert.IsZero() || ctx.clientIdleTimeout != 0 { + if !child.disconnectExpiredCert.IsZero() || child.clientIdleTimeout != 0 { mon, err := NewMonitor(MonitorConfig{ - DisconnectExpiredCert: ctx.disconnectExpiredCert, - ClientIdleTimeout: ctx.clientIdleTimeout, - Clock: ctx.srv.GetClock(), - Tracker: ctx, - Conn: ctx.Conn, - Context: cancelContext, - TeleportUser: ctx.Identity.TeleportUser, - Login: ctx.Identity.Login, - ServerID: ctx.srv.ID(), - Audit: ctx.srv.GetAuditLog(), - Entry: ctx.Entry, + DisconnectExpiredCert: child.disconnectExpiredCert, + ClientIdleTimeout: child.clientIdleTimeout, + Clock: child.srv.GetClock(), + Tracker: child, + Conn: child.ServerConn, + Context: ctx, + TeleportUser: child.Identity.TeleportUser, + Login: child.Identity.Login, + ServerID: child.srv.ID(), + Audit: child.srv.GetAuditLog(), + Entry: child.Entry, }) if err != nil { - ctx.Close() - return nil, trace.Wrap(err) + child.Close() + return nil, nil, trace.Wrap(err) } go mon.Start() } // Create pipe used to send command to child process. - ctx.cmdr, ctx.cmdw, err = os.Pipe() + child.cmdr, child.cmdw, err = os.Pipe() if err != nil { - return nil, trace.Wrap(err) + child.Close() + return nil, nil, trace.Wrap(err) } - ctx.AddCloser(ctx.cmdr) - ctx.AddCloser(ctx.cmdw) + child.AddCloser(child.cmdr) + child.AddCloser(child.cmdw) // Create pipe used to signal continue to child process. - ctx.contr, ctx.contw, err = os.Pipe() + child.contr, child.contw, err = os.Pipe() if err != nil { - return nil, trace.Wrap(err) + child.Close() + return nil, nil, trace.Wrap(err) } - ctx.AddCloser(ctx.contr) - ctx.AddCloser(ctx.contw) + child.AddCloser(child.contr) + child.AddCloser(child.contw) - // gather environment variables from parent. - ctx.ImportParentEnv() + return ctx, child, nil +} - return ctx, nil +// Parent grants access to the connection-level context of which this +// is a subcontext. Useful for unambiguously accessing methods which +// this subcontext overrides (e.g. child.Parent().SetEnv(...)). +func (c *ServerContext) Parent() *sshutils.ConnectionContext { + return c.ConnectionContext } // ID returns ID of this context @@ -417,9 +414,9 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error { // update ctx with a session ID c.session, _ = findSession() if c.session == nil { - log.Debugf("Will create new session for SSH connection %v.", c.Conn.RemoteAddr()) + log.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr()) } else { - log.Debugf("Will join session %v for SSH connection %v.", c.session, c.Conn.RemoteAddr()) + log.Debugf("Will join session %v for SSH connection %v.", c.session, c.ServerConn.RemoteAddr()) } return nil @@ -448,24 +445,6 @@ func (c *ServerContext) AddCloser(closer io.Closer) { c.closers = append(c.closers, closer) } -// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent, -// or nil if no agent is available in this context. -func (c *ServerContext) GetAgent() agent.Agent { - if c.Parent == nil { - return nil - } - return c.Parent.GetAgent() -} - -// GetAgentChannel returns the channel over which communication with the agent occurs, -// or nil if no agent is available in this context. -func (c *ServerContext) GetAgentChannel() ssh.Channel { - if c.Parent == nil { - return nil - } - return c.Parent.GetAgentChannel() -} - // GetTerm returns a Terminal. 
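// Sketch of the new constructor shape at a call site (mirrors the forwarding
// and regular servers later in this patch; error handling is abbreviated).
// The returned ctx is canceled when the ServerContext closes, and
// connection-level fields such as ServerConn/NetConn are now reached through
// the embedded *sshutils.ConnectionContext.
ctx, scx, err := srv.NewServerContext(ctx, ccx, s, identityContext)
if err != nil {
	return trace.Wrap(err)
}
defer scx.Close() // cancels ctx and releases registered closers
log.Debugf("serving %v", scx.ServerConn.RemoteAddr()) // promoted field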
func (c *ServerContext) GetTerm() Terminal { c.RLock() @@ -482,6 +461,16 @@ func (c *ServerContext) SetTerm(t Terminal) { c.term = t } +// VisitEnv grants visitor-style access to env variables. +func (c *ServerContext) VisitEnv(visit func(key, val string)) { + // visit the parent env first since locally defined variables + // effectively "override" parent defined variables. + c.Parent().VisitEnv(visit) + for key, val := range c.env { + visit(key, val) + } +} + // SetEnv sets a environment variable within this context. func (c *ServerContext) SetEnv(key, val string) { c.env[key] = val @@ -490,13 +479,10 @@ func (c *ServerContext) SetEnv(key, val string) { // GetEnv returns a environment variable within this context. func (c *ServerContext) GetEnv(key string) (string, bool) { val, ok := c.env[key] - return val, ok -} - -// ImportParentEnv is used to re-synchronize env vars after -// parent context has been updated. -func (c *ServerContext) ImportParentEnv() { - c.Parent.ExportEnv(c.env) + if ok { + return val, true + } + return c.Parent().GetEnv(key) } // takeClosers returns all resources that should be closed and sets the properties to null @@ -543,11 +529,11 @@ func (c *ServerContext) reportStats(conn utils.Stater) { events.SessionServerID: c.GetServer().HostUUID(), events.EventLogin: c.Identity.Login, events.EventUser: c.Identity.TeleportUser, - events.RemoteAddr: c.Conn.RemoteAddr().String(), + events.RemoteAddr: c.ServerConn.RemoteAddr().String(), events.EventIndex: events.SessionDataIndex, } if !c.srv.UseTunnel() { - eventFields[events.LocalAddr] = c.Conn.LocalAddr().String() + eventFields[events.LocalAddr] = c.ServerConn.LocalAddr().String() } if c.session != nil { eventFields[events.SessionEventID] = c.session.id @@ -564,7 +550,7 @@ func (c *ServerContext) reportStats(conn utils.Stater) { func (c *ServerContext) Close() error { // If the underlying connection is holding tracking information, report that // to the audit log at close. - if stats, ok := c.Connection.(*utils.TrackingConn); ok { + if stats, ok := c.NetConn.(*utils.TrackingConn); ok { defer c.reportStats(stats) } @@ -580,14 +566,10 @@ func (c *ServerContext) Close() error { return nil } -// CancelContext is a context associated with server context, -// closed whenever this server context is closed -func (c *ServerContext) CancelContext() context.Context { - return c.cancelContext -} - -// Cancel is a function that triggers closure -func (c *ServerContext) Cancel() context.CancelFunc { +// CancelFunc gets the context.CancelFunc associated with +// this context. Not a substitute for calling the +// ServerContext.Close method. +func (c *ServerContext) CancelFunc() context.CancelFunc { return c.cancel } @@ -638,7 +620,7 @@ func (c *ServerContext) ProxyPublicAddress() string { } func (c *ServerContext) String() string { - return fmt.Sprintf("ServerContext(%v->%v, user=%v, id=%v)", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), c.Conn.User(), c.id) + return fmt.Sprintf("ServerContext(%v->%v, user=%v, id=%v)", c.ServerConn.RemoteAddr(), c.ServerConn.LocalAddr(), c.ServerConn.User(), c.id) } // ExecCommand takes a *ServerContext and extracts the parts needed to create @@ -700,18 +682,18 @@ func (c *ServerContext) ExecCommand() (*execCommand, error) { func buildEnvironment(ctx *ServerContext) []string { var env []string - // Apply environment variables passed in from client. 
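// Sketch: env vars are now layered instead of copied. GetEnv consults the
// session-level map first and falls back to the parent connection, so a
// value like SSH_AUTH_SOCK set via ctx.Parent().SetEnv is visible to all
// later sessions without the old ImportParentEnv re-sync. VisitEnv flattens
// both layers, e.g. when building a child process environment (scx is a
// *ServerContext; the snippet mirrors buildEnvironment below):
var env []string
scx.VisitEnv(func(key, val string) {
	env = append(env, fmt.Sprintf("%s=%s", key, val))
})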
- for k, v := range ctx.env { - env = append(env, fmt.Sprintf("%s=%s", k, v)) - } + // gather all dynamically defined environment variables + ctx.VisitEnv(func(key, val string) { + env = append(env, fmt.Sprintf("%s=%s", key, val)) + }) // Parse the local and remote addresses to build SSH_CLIENT and // SSH_CONNECTION environment variables. - remoteHost, remotePort, err := net.SplitHostPort(ctx.Conn.RemoteAddr().String()) + remoteHost, remotePort, err := net.SplitHostPort(ctx.ServerConn.RemoteAddr().String()) if err != nil { log.Debugf("Failed to split remote address: %v.", err) } else { - localHost, localPort, err := net.SplitHostPort(ctx.Conn.LocalAddr().String()) + localHost, localPort, err := net.SplitHostPort(ctx.ServerConn.LocalAddr().String()) if err != nil { log.Debugf("Failed to split local address: %v.", err) } else { diff --git a/lib/srv/exec.go b/lib/srv/exec.go index e01e385d41107..451b288ea60e0 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -248,8 +248,8 @@ func (e *localExec) transformSecureCopy() error { } e.Command = fmt.Sprintf("%s scp --remote-addr=%s --local-addr=%s %v", teleportBin, - e.Ctx.Conn.RemoteAddr().String(), - e.Ctx.Conn.LocalAddr().String(), + e.Ctx.ServerConn.RemoteAddr().String(), + e.Ctx.ServerConn.LocalAddr().String(), strings.Join(args[1:], " ")) return nil @@ -367,8 +367,8 @@ func emitExecAuditEvent(ctx *ServerContext, cmd string, execErr error) { fields := events.EventFields{ events.EventUser: ctx.Identity.TeleportUser, events.EventLogin: ctx.Identity.Login, - events.LocalAddr: ctx.Conn.LocalAddr().String(), - events.RemoteAddr: ctx.Conn.RemoteAddr().String(), + events.LocalAddr: ctx.ServerConn.LocalAddr().String(), + events.RemoteAddr: ctx.ServerConn.RemoteAddr().String(), events.EventNamespace: ctx.srv.GetNamespace(), // Due to scp being inherently vulnerable to command injection, always // make sure the full command and exit code is recorded for accountability. diff --git a/lib/srv/exec_test.go b/lib/srv/exec_test.go index 5c4642bba6587..8a54b76b4350c 100644 --- a/lib/srv/exec_test.go +++ b/lib/srv/exec_test.go @@ -112,8 +112,9 @@ func (s *ExecSuite) SetUpSuite(c *check.C) { s.usr, _ = user.Current() s.ctx = &ServerContext{ - IsTestStub: true, - ClusterName: "localhost", + ConnectionContext: &sshutils.ConnectionContext{}, + IsTestStub: true, + ClusterName: "localhost", srv: &fakeServer{ accessPoint: s.a, auditLog: &fakeLog{}, @@ -123,7 +124,7 @@ func (s *ExecSuite) SetUpSuite(c *check.C) { s.ctx.Identity.Login = s.usr.Username s.ctx.session = &session{id: "xxx", term: &fakeTerminal{f: f}} s.ctx.Identity.TeleportUser = "galt" - s.ctx.Conn = &ssh.ServerConn{Conn: s} + s.ctx.ServerConn = &ssh.ServerConn{Conn: s} s.ctx.ExecRequest = &localExec{Ctx: s.ctx} s.ctx.request = &ssh.Request{ Type: sshutils.ExecRequest, @@ -257,7 +258,8 @@ func (s *ExecSuite) TestContinue(c *check.C) { // Create a fake context that will be used to configure a command that will // re-exec "ls". 
ctx := &ServerContext{ - IsTestStub: true, + ConnectionContext: &sshutils.ConnectionContext{}, + IsTestStub: true, srv: &fakeServer{ accessPoint: s.a, auditLog: &fakeLog{}, @@ -266,7 +268,7 @@ func (s *ExecSuite) TestContinue(c *check.C) { } ctx.Identity.Login = s.usr.Username ctx.Identity.TeleportUser = "galt" - ctx.Conn = &ssh.ServerConn{Conn: s} + ctx.ServerConn = &ssh.ServerConn{Conn: s} ctx.ExecRequest = &localExec{ Ctx: ctx, Command: lsPath, diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 0fa501b906fcf..3f80797c09293 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/proxy" "github.com/gravitational/trace" @@ -97,7 +98,7 @@ type Server struct { identityContext srv.IdentityContext // userAgent is the SSH user agent that was forwarded to the proxy. - userAgent agent.Agent + userAgent teleagent.Agent // hostCertificate is the SSH host certificate this in-memory server presents // to the client. @@ -133,11 +134,6 @@ type Server struct { sessionServer session.Service dataDir string - // closeContext and closeCancel are used to signal when the in-memory - // server is closing and all blocking goroutines should unblock. - closeContext context.Context - closeCancel context.CancelFunc - clock clockwork.Clock // hostUUID is the UUID of the underlying proxy that the forwarding server @@ -148,7 +144,7 @@ type Server struct { // ServerConfig is the configuration needed to create an instance of a Server. type ServerConfig struct { AuthClient auth.ClientI - UserAgent agent.Agent + UserAgent teleagent.Agent TargetConn net.Conn SrcAddr net.Addr DstAddr net.Addr @@ -287,10 +283,6 @@ func New(c ServerConfig) (*Server, error) { SessionRegistry: s.sessionRegistry, } - // Create a close context that is used internally to signal when the server - // is closing and for any blocking goroutines to unblock. - s.closeContext, s.closeCancel = context.WithCancel(context.Background()) - return s, nil } @@ -430,6 +422,7 @@ func (s *Server) Serve() { sconn, chans, reqs, err := ssh.NewServerConn(s.serverConn, config) if err != nil { + s.userAgent.Close() s.targetConn.Close() s.clientConn.Close() s.serverConn.Close() @@ -439,11 +432,13 @@ func (s *Server) Serve() { } s.sconn = sconn - s.connectionContext = sshutils.NewConnectionContext(s.serverConn, s.sconn) + ctx := context.Background() + ctx, s.connectionContext = sshutils.NewConnectionContext(ctx, s.serverConn, s.sconn) // Take connection and extract identity information for the user from it. 
s.identityContext, err = s.authHandlers.CreateIdentityContext(sconn) if err != nil { + s.userAgent.Close() s.targetConn.Close() s.clientConn.Close() s.serverConn.Close() @@ -461,6 +456,7 @@ func (s *Server) Serve() { s.rejectChannel(chans, err.Error()) sconn.Close() + s.userAgent.Close() s.targetConn.Close() s.clientConn.Close() s.serverConn.Close() @@ -479,16 +475,17 @@ func (s *Server) Serve() { }, Interval: clusterConfig.GetKeepAliveInterval(), MaxCount: clusterConfig.GetKeepAliveCountMax(), - CloseContext: s.closeContext, - CloseCancel: s.closeCancel, + CloseContext: ctx, + CloseCancel: func() { s.connectionContext.Close() }, }) - go s.handleConnection(chans, reqs) + go s.handleConnection(ctx, chans, reqs) } // Close will close all underlying connections that the forwarding server holds. func (s *Server) Close() error { conns := []io.Closer{ + s.userAgent, s.sconn, s.clientConn, s.serverConn, @@ -510,10 +507,6 @@ func (s *Server) Close() error { } } - // Signal to waiting goroutines that the server is closing (for example, - // the keep alive loop). - s.closeCancel() - return trace.NewAggregate(errs...) } @@ -554,7 +547,7 @@ func (s *Server) newRemoteClient(systemLogin string) (*ssh.Client, error) { return client, nil } -func (s *Server) handleConnection(chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { +func (s *Server) handleConnection(ctx context.Context, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { defer s.log.Debugf("Closing forwarding server connected to %v and releasing resources.", s.sconn.LocalAddr()) defer s.Close() @@ -571,10 +564,10 @@ func (s *Server) handleConnection(chans <-chan ssh.NewChannel, reqs <-chan *ssh. if newChannel == nil { return } - go s.handleChannel(newChannel) + go s.handleChannel(ctx, newChannel) // If the server is closing (either the heartbeat failed or Close() was // called, exit out of the connection handler loop. - case <-s.closeContext.Done(): + case <-ctx.Done(): return } } @@ -615,7 +608,7 @@ func (s *Server) handleGlobalRequest(req *ssh.Request) { } } -func (s *Server) handleChannel(nch ssh.NewChannel) { +func (s *Server) handleChannel(ctx context.Context, nch ssh.NewChannel) { channelType := nch.ChannelType() switch channelType { @@ -630,7 +623,7 @@ func (s *Server) handleChannel(nch ssh.NewChannel) { } return } - go s.handleSessionRequests(ch, requests) + go s.handleSessionRequests(ctx, ch, requests) // Channels of type "direct-tcpip" handles request for port forwarding. case teleport.ChanDirectTCPIP: req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData()) @@ -649,7 +642,7 @@ func (s *Server) handleChannel(nch ssh.NewChannel) { } return } - go s.handleDirectTCPIPRequest(ch, req) + go s.handleDirectTCPIPRequest(ctx, ch, req) default: if err := nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)); err != nil { s.log.Warnf("Failed to reject channel of unknown type: %v", err) @@ -658,43 +651,42 @@ func (s *Server) handleChannel(nch ssh.NewChannel) { } // handleDirectTCPIPRequest handles port forwarding requests. -func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTCPIPReq) { +func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, req *sshutils.DirectTCPIPReq) { // Create context for this channel. This context will be closed when // forwarding is complete. 
- ctx, err := srv.NewServerContext(s.connectionContext, s, s.identityContext) + ctx, scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) if err != nil { - ctx.Errorf("Unable to create connection context: %v.", err) + scx.Errorf("Unable to create connection context: %v.", err) s.stderrWrite(ch, "Unable to create connection context.") return } - ctx.Connection = s.serverConn - ctx.RemoteClient = s.remoteClient - ctx.ChannelType = teleport.ChanDirectTCPIP - ctx.SrcAddr = fmt.Sprintf("%v:%d", req.Orig, req.OrigPort) - ctx.DstAddr = fmt.Sprintf("%v:%d", req.Host, req.Port) - defer ctx.Close() + scx.RemoteClient = s.remoteClient + scx.ChannelType = teleport.ChanDirectTCPIP + scx.SrcAddr = fmt.Sprintf("%v:%d", req.Orig, req.OrigPort) + scx.DstAddr = fmt.Sprintf("%v:%d", req.Host, req.Port) + defer scx.Close() // Check if the role allows port forwarding for this user. - err = s.authHandlers.CheckPortForward(ctx.DstAddr, ctx) + err = s.authHandlers.CheckPortForward(scx.DstAddr, scx) if err != nil { s.stderrWrite(ch, err.Error()) return } - s.log.Debugf("Opening direct-tcpip channel from %v to %v in context %v.", ctx.SrcAddr, ctx.DstAddr, ctx.ID()) - defer s.log.Debugf("Completing direct-tcpip request from %v to %v in context %v.", ctx.SrcAddr, ctx.DstAddr, ctx.ID()) + s.log.Debugf("Opening direct-tcpip channel from %v to %v in context %v.", scx.SrcAddr, scx.DstAddr, scx.ID()) + defer s.log.Debugf("Completing direct-tcpip request from %v to %v in context %v.", scx.SrcAddr, scx.DstAddr, scx.ID()) // Create "direct-tcpip" channel from the remote host to the target host. - conn, err := s.remoteClient.Dial("tcp", ctx.DstAddr) + conn, err := s.remoteClient.Dial("tcp", scx.DstAddr) if err != nil { - ctx.Infof("Failed to connect to: %v: %v", ctx.DstAddr, err) + scx.Infof("Failed to connect to: %v: %v", scx.DstAddr, err) return } defer conn.Close() // Emit a port forwarding audit event. s.EmitAuditEvent(events.PortForward, events.EventFields{ - events.PortForwardAddr: ctx.DstAddr, + events.PortForwardAddr: scx.DstAddr, events.PortForwardSuccess: true, events.EventLogin: s.identityContext.Login, events.EventUser: s.identityContext.TeleportUser, @@ -702,12 +694,13 @@ func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTC events.RemoteAddr: s.sconn.RemoteAddr().String(), }) - wg := &sync.WaitGroup{} + var wg sync.WaitGroup + wch := make(chan struct{}) wg.Add(1) go func() { defer wg.Done() if _, err := io.Copy(ch, conn); err != nil { - ctx.Warningf("failed proxying data for port forwarding connection: %v", err) + scx.Warningf("failed proxying data for port forwarding connection: %v", err) } ch.Close() }() @@ -715,34 +708,41 @@ func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTC go func() { defer wg.Done() if _, err := io.Copy(conn, ch); err != nil { - ctx.Warningf("failed proxying data for port forwarding connection: %v", err) + scx.Warningf("failed proxying data for port forwarding connection: %v", err) } conn.Close() }() - - wg.Wait() + // block on wg in separate goroutine so that we + // can select on wg and context cancellation. + go func() { + defer close(wch) + wg.Wait() + }() + select { + case <-wch: + case <-ctx.Done(): + } } // handleSessionRequests handles out of band session requests once the session // channel has been created this function's loop handles all the "exec", // "subsystem" and "shell" requests. 
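// Sketch of the shutdown pattern used in handleDirectTCPIPRequest above:
// sync.WaitGroup.Wait cannot appear in a select, so it is wrapped in a
// goroutine that closes a channel, letting the handler return early when
// the connection-scoped ctx is canceled while the two io.Copy goroutines
// are still draining.
done := make(chan struct{})
go func() {
	defer close(done)
	wg.Wait()
}()
select {
case <-done: // both copy directions finished
case <-ctx.Done(): // connection closed underneath us
}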
-func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) { +func (s *Server) handleSessionRequests(ctx context.Context, ch ssh.Channel, in <-chan *ssh.Request) { // Create context for this channel. This context will be closed when the // session request is complete. // There is no need for the forwarding server to initiate disconnects, // based on teleport business logic, because this logic is already // done on the server's terminating side. - ctx, err := srv.NewServerContext(s.connectionContext, s, s.identityContext) + ctx, scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) if err != nil { - ctx.Errorf("Unable to create connection context: %v.", err) + scx.Errorf("Unable to create connection context: %v.", err) s.stderrWrite(ch, "Unable to create connection context.") return } - ctx.Connection = s.serverConn - ctx.RemoteClient = s.remoteClient - ctx.AddCloser(ch) - ctx.ChannelType = teleport.ChanSession - defer ctx.Close() + scx.RemoteClient = s.remoteClient + scx.AddCloser(ch) + scx.ChannelType = teleport.ChanSession + defer scx.Close() // Create a "session" channel on the remote host. remoteSession, err := s.remoteClient.NewSession() @@ -750,57 +750,59 @@ func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) { s.stderrWrite(ch, err.Error()) return } - ctx.RemoteSession = remoteSession + scx.RemoteSession = remoteSession - s.log.Debugf("Opening session request to %v in context %v.", s.sconn.RemoteAddr(), ctx.ID()) - defer s.log.Debugf("Closing session request to %v in context %v.", s.sconn.RemoteAddr(), ctx.ID()) + s.log.Debugf("Opening session request to %v in context %v.", s.sconn.RemoteAddr(), scx.ID()) + defer s.log.Debugf("Closing session request to %v in context %v.", s.sconn.RemoteAddr(), scx.ID()) for { // Update the context with the session ID. - err := ctx.CreateOrJoinSession(s.sessionRegistry) + err := scx.CreateOrJoinSession(s.sessionRegistry) if err != nil { errorMessage := fmt.Sprintf("unable to update context: %v", err) - ctx.Errorf("%v", errorMessage) + scx.Errorf("%v", errorMessage) // Write the error to channel and close it. s.stderrWrite(ch, errorMessage) _, err := ch.SendRequest("exit-status", false, ssh.Marshal(struct{ C uint32 }{C: teleport.RemoteCommandFailure})) if err != nil { - ctx.Errorf("Failed to send exit status %v", errorMessage) + scx.Errorf("Failed to send exit status %v", errorMessage) } return } select { - case result := <-ctx.SubsystemResultCh: + case result := <-scx.SubsystemResultCh: // Subsystem has finished executing, close the channel and session. - ctx.Debugf("Subsystem execution result: %v", result.Err) + scx.Debugf("Subsystem execution result: %v", result.Err) return case req := <-in: if req == nil { // The client has closed or dropped the connection. 
- ctx.Debugf("Client %v disconnected", s.sconn.RemoteAddr()) + scx.Debugf("Client %v disconnected", s.sconn.RemoteAddr()) return } - if err := s.dispatch(ch, req, ctx); err != nil { + if err := s.dispatch(ch, req, scx); err != nil { s.replyError(ch, req, err) return } if req.WantReply { if err := req.Reply(true, nil); err != nil { - ctx.Errorf("failed sending OK response on %q request: %v", req.Type, err) + scx.Errorf("failed sending OK response on %q request: %v", req.Type, err) } } - case result := <-ctx.ExecResultCh: - ctx.Debugf("Exec request (%q) complete: %v", result.Command, result.Code) + case result := <-scx.ExecResultCh: + scx.Debugf("Exec request (%q) complete: %v", result.Command, result.Code) // The exec process has finished and delivered the execution result, send // the result back to the client, and close the session and channel. _, err := ch.SendRequest("exit-status", false, ssh.Marshal(struct{ C uint32 }{C: uint32(result.Code)})) if err != nil { - ctx.Infof("Failed to send exit status for %v: %v", result.Command, err) + scx.Infof("Failed to send exit status for %v: %v", result.Command, err) } + return + case <-ctx.Done(): return } } diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 31b75af7206c5..e5e52bb9d490c 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -27,7 +27,6 @@ import ( "sync" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" @@ -45,13 +44,13 @@ import ( // proxySubsys implements an SSH subsystem for proxying listening sockets from // remote hosts to a proxy client (AKA port mapping) type proxySubsys struct { - proxySubsysConfig - log *logrus.Entry - closeC chan struct{} - error error - closeOnce sync.Once - agent agent.Agent - agentChannel ssh.Channel + proxySubsysRequest + srv *Server + ctx *srv.ServerContext + log *logrus.Entry + closeC chan struct{} + error error + closeOnce sync.Once } // parseProxySubsys looks at the requested subsystem name and returns a fully configured @@ -62,7 +61,7 @@ type proxySubsys struct { // "proxy:@clustername" - Teleport request to connect to an auth server for cluster with name 'clustername' // "proxy:host:22@clustername" - Teleport request to connect to host:22 on cluster 'clustername' // "proxy:host:22@namespace@clustername" -func parseProxySubsys(request string, srv *Server, ctx *srv.ServerContext) (*proxySubsys, error) { +func parseProxySubsysRequest(request string) (proxySubsysRequest, error) { log.Debugf("parse_proxy_subsys(%q)", request) var ( clusterName string @@ -73,7 +72,7 @@ func parseProxySubsys(request string, srv *Server, ctx *srv.ServerContext) (*pro const prefix = "proxy:" // get rid of 'proxy:' prefix: if strings.Index(request, prefix) != 0 { - return nil, trace.BadParameter(paramMessage) + return proxySubsysRequest{}, trace.BadParameter(paramMessage) } requestBody := strings.TrimPrefix(request, prefix) namespace := defaults.Namespace @@ -81,100 +80,100 @@ func parseProxySubsys(request string, srv *Server, ctx *srv.ServerContext) (*pro parts := strings.Split(requestBody, "@") switch { case len(parts) == 0: // "proxy:" - return nil, trace.BadParameter(paramMessage) + return proxySubsysRequest{}, trace.BadParameter(paramMessage) case len(parts) == 1: // "proxy:host:22" targetHost, targetPort, err = utils.SplitHostPort(parts[0]) if err != nil { - return nil, trace.BadParameter(paramMessage) + return proxySubsysRequest{}, trace.BadParameter(paramMessage) } case len(parts) == 
2: // "proxy:@clustername" or "proxy:host:22@clustername" if parts[0] != "" { targetHost, targetPort, err = utils.SplitHostPort(parts[0]) if err != nil { - return nil, trace.BadParameter(paramMessage) + return proxySubsysRequest{}, trace.BadParameter(paramMessage) } } clusterName = parts[1] if clusterName == "" && targetHost == "" { - return nil, trace.BadParameter("invalid format for proxy request: missing cluster name or target host in %q", request) + return proxySubsysRequest{}, trace.BadParameter("invalid format for proxy request: missing cluster name or target host in %q", request) } case len(parts) >= 3: // "proxy:host:22@namespace@clustername" clusterName = strings.Join(parts[2:], "@") namespace = parts[1] targetHost, targetPort, err = utils.SplitHostPort(parts[0]) if err != nil { - return nil, trace.BadParameter(paramMessage) + return proxySubsysRequest{}, trace.BadParameter(paramMessage) } } - return newProxySubsys(proxySubsysConfig{ + return proxySubsysRequest{ namespace: namespace, - srv: srv, - ctx: ctx, host: targetHost, port: targetPort, clusterName: clusterName, - }) + }, nil +} + +// parseProxySubsys decodes a proxy subsystem request and sets up a proxy subsystem instance. +// See parseProxySubsysRequest for details on the request format. +func parseProxySubsys(request string, srv *Server, ctx *srv.ServerContext) (*proxySubsys, error) { + req, err := parseProxySubsysRequest(request) + if err != nil { + return nil, trace.Wrap(err) + } + subsys, err := newProxySubsys(ctx, srv, req) + if err != nil { + return nil, trace.Wrap(err) + } + return subsys, nil } -// proxySubsysConfig is a proxy subsystem configuration -type proxySubsysConfig struct { +// proxySubsysRequest encodes proxy subsystem request parameters. +type proxySubsysRequest struct { namespace string host string port string clusterName string - srv *Server - ctx *srv.ServerContext } -func (p *proxySubsysConfig) String() string { +func (p *proxySubsysRequest) String() string { return fmt.Sprintf("host=%v, port=%v, cluster=%v", p.host, p.port, p.clusterName) } -// CheckAndSetDefaults checks and sets defaults -func (p *proxySubsysConfig) CheckAndSetDefaults() error { +// SetDefaults sets default values. 
+func (p *proxySubsysRequest) SetDefaults() { if p.namespace == "" { p.namespace = defaults.Namespace } - if p.srv == nil { - return trace.BadParameter("missing parameter server") - } - if p.ctx == nil { - return trace.BadParameter("missing parameter context") - } - if p.clusterName == "" && p.ctx.Identity.RouteToCluster != "" { - log.Debugf("Proxy subsystem: routing user %q to cluster %q based on the route to cluster extension.", - p.ctx.Identity.TeleportUser, p.ctx.Identity.RouteToCluster, - ) - p.clusterName = p.ctx.Identity.RouteToCluster - } - if p.clusterName != "" && p.srv.proxyTun != nil { - _, err := p.srv.proxyTun.GetSite(p.clusterName) - if err != nil { - return trace.BadParameter("invalid format for proxy request: unknown cluster %q", p.clusterName) - } - } - - return nil } // newProxySubsys is a helper that creates a proxy subsystem from // a port forwarding request, used to implement ProxyJump feature in proxy // and reuse the code -func newProxySubsys(cfg proxySubsysConfig) (*proxySubsys, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) +func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) (*proxySubsys, error) { + req.SetDefaults() + if req.clusterName == "" && ctx.Identity.RouteToCluster != "" { + log.Debugf("Proxy subsystem: routing user %q to cluster %q based on the route to cluster extension.", + ctx.Identity.TeleportUser, ctx.Identity.RouteToCluster, + ) + req.clusterName = ctx.Identity.RouteToCluster + } + if req.clusterName != "" && srv.proxyTun != nil { + _, err := srv.proxyTun.GetSite(req.clusterName) + if err != nil { + return nil, trace.BadParameter("invalid format for proxy request: unknown cluster %q", req.clusterName) + } } - log.Debugf("newProxySubsys(%v).", cfg) + log.Debugf("newProxySubsys(%v).", req) return &proxySubsys{ - proxySubsysConfig: cfg, + proxySubsysRequest: req, + ctx: ctx, + srv: srv, log: logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentSubsystemProxy, trace.ComponentFields: map[string]string{}, }), - closeC: make(chan struct{}), - agent: cfg.ctx.GetAgent(), - agentChannel: cfg.ctx.GetAgentChannel(), + closeC: make(chan struct{}), }, nil } @@ -410,12 +409,12 @@ func (t *proxySubsys) proxyToHost( Addr: serverAddr, } conn, err := site.Dial(reversetunnel.DialParams{ - From: remoteAddr, - To: toAddr, - UserAgent: t.agent, - Address: t.host, - ServerID: serverID, - Principals: principals, + From: remoteAddr, + To: toAddr, + GetUserAgent: t.ctx.StartAgentChannel, + Address: t.host, + ServerID: serverID, + Principals: principals, }) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/regular/proxy_test.go b/lib/srv/regular/proxy_test.go index a3d2d305f64c0..ee904aca07b2a 100644 --- a/lib/srv/regular/proxy_test.go +++ b/lib/srv/regular/proxy_test.go @@ -17,6 +17,7 @@ limitations under the License. 
package regular import ( + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv" "gopkg.in/check.v1" @@ -35,52 +36,54 @@ func (s *ProxyTestSuite) SetUpSuite(c *check.C) { } func (s *ProxyTestSuite) TestParseProxyRequest(c *check.C) { - ctx := &srv.ServerContext{} - - // proxy request for a host:port - subsys, err := parseProxySubsys("proxy:host:22", s.srv, ctx) - c.Assert(err, check.IsNil) - c.Assert(subsys, check.NotNil) - c.Assert(subsys.srv, check.Equals, s.srv) - c.Assert(subsys.host, check.Equals, "host") - c.Assert(subsys.port, check.Equals, "22") - c.Assert(subsys.clusterName, check.Equals, "") - - // similar request, just with '@' at the end (missing site) - subsys, err = parseProxySubsys("proxy:host:22@", s.srv, ctx) - c.Assert(err, check.IsNil) - c.Assert(subsys.srv, check.Equals, s.srv) - c.Assert(subsys.host, check.Equals, "host") - c.Assert(subsys.port, check.Equals, "22") - c.Assert(subsys.clusterName, check.Equals, "") - // proxy request for just the sitename - subsys, err = parseProxySubsys("proxy:@moon", s.srv, ctx) - c.Assert(err, check.IsNil) - c.Assert(subsys, check.NotNil) - c.Assert(subsys.srv, check.Equals, s.srv) - c.Assert(subsys.host, check.Equals, "") - c.Assert(subsys.port, check.Equals, "") - c.Assert(subsys.clusterName, check.Equals, "moon") - - // proxy request for the host:port@sitename - subsys, err = parseProxySubsys("proxy:station:100@moon", s.srv, ctx) - c.Assert(err, check.IsNil) - c.Assert(subsys, check.NotNil) - c.Assert(subsys.srv, check.Equals, s.srv) - c.Assert(subsys.host, check.Equals, "station") - c.Assert(subsys.port, check.Equals, "100") - c.Assert(subsys.clusterName, check.Equals, "moon") + tt := []struct { + req, host, port, cluster, namespace string + }{ + { // proxy request for a host:port + req: "proxy:host:22", + host: "host", + port: "22", + }, + { // similar request, just with '@' at the end (missing site) + req: "proxy:host:22@", + host: "host", + port: "22", + }, + { // proxy request for just the sitename + req: "proxy:@moon", + cluster: "moon", + }, + { // proxy request for the host:port@sitename + req: "proxy:station:100@moon", + host: "station", + port: "100", + cluster: "moon", + }, + { // proxy request for the host:port@namespace@cluster + req: "proxy:station:100@system@moon", + host: "station", + port: "100", + cluster: "moon", + namespace: "system", + }, + } - // proxy request for the host:port@namespace@cluster - subsys, err = parseProxySubsys("proxy:station:100@system@moon", s.srv, ctx) - c.Assert(err, check.IsNil) - c.Assert(subsys, check.NotNil) - c.Assert(subsys.srv, check.Equals, s.srv) - c.Assert(subsys.host, check.Equals, "station") - c.Assert(subsys.port, check.Equals, "100") - c.Assert(subsys.clusterName, check.Equals, "moon") - c.Assert(subsys.namespace, check.Equals, "system") + for i, t := range tt { + if t.namespace == "" { + // test cases without a defined namespace are testing for + // the presence of the default namespace; namespace should + // never actually be empty. 
+ t.namespace = defaults.Namespace + } + cmt := check.Commentf("Test case %d: %+v", i, t) + req, err := parseProxySubsysRequest(t.req) + c.Assert(err, check.IsNil, cmt) + c.Assert(req.host, check.Equals, t.host, cmt) + c.Assert(req.port, check.Equals, t.port, cmt) + c.Assert(req.clusterName, check.Equals, t.cluster, cmt) + c.Assert(req.namespace, check.Equals, t.namespace, cmt) + } } func (s *ProxyTestSuite) TestParseBadRequests(c *check.C) { diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 069b04f5b0cde..69b374b47c74f 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -34,7 +34,6 @@ import ( "time" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/auth" @@ -775,20 +774,18 @@ func (s *Server) serveAgent(ctx *srv.ServerContext) error { return trace.ConvertSystemError(err) } - // start an agent on a unix socket - agentServer := &teleagent.AgentServer{Agent: ctx.Parent.GetAgent()} + // start an agent server on a unix socket. each incoming connection + // will result in a separate agent request. + agentServer := teleagent.NewServer(ctx.Parent().StartAgentChannel) err = agentServer.ListenUnixSocket(socketPath, uid, gid, 0600) if err != nil { return trace.Wrap(err) } - ctx.Parent.SetEnv(teleport.SSHAuthSock, socketPath) - ctx.Parent.SetEnv(teleport.SSHAgentPID, fmt.Sprintf("%v", pid)) - ctx.Parent.AddCloser(agentServer) - ctx.Parent.AddCloser(dirCloser) - // ensure that SSHAuthSock and SSHAgentPID are imported into - // the current child context. - ctx.ImportParentEnv() - ctx.Debugf("Opened agent channel for Teleport user %v and socket %v.", ctx.Identity.TeleportUser, socketPath) + ctx.Parent().SetEnv(teleport.SSHAuthSock, socketPath) + ctx.Parent().SetEnv(teleport.SSHAgentPID, fmt.Sprintf("%v", pid)) + ctx.Parent().AddCloser(agentServer) + ctx.Parent().AddCloser(dirCloser) + ctx.Debugf("Starting agent server for Teleport user %v and socket %v.", ctx.Identity.TeleportUser, socketPath) go func() { if err := agentServer.Serve(); err != nil { ctx.Errorf("agent server for user %q stopped: %v", ctx.Identity.TeleportUser, err) @@ -840,7 +837,7 @@ func (s *Server) HandleRequest(r *ssh.Request) { } // HandleNewChan is called when new channel is opened -func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { +func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { identityContext, err := s.authHandlers.CreateIdentityContext(ccx.ServerConn) if err != nil { rejectChannel(nch, ssh.Prohibited, fmt.Sprintf("Unable to create identity from connection: %v", err)) @@ -865,7 +862,7 @@ func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChann rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleProxyJump(ccx, identityContext, ch, *req) + go s.handleProxyJump(ctx, ccx, identityContext, ch, *req) return // Channels of type "session" handle requests that are involved in running // commands on a server. 
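// Sketch: with the per-connection agent server above, every client that
// dials the forwarded $SSH_AUTH_SOCK triggers StartAgentChannel and gets
// its own short-lived agent channel, closed when that client disconnects.
// Illustrative client side using golang.org/x/crypto/ssh/agent:
conn, err := net.Dial("unix", socketPath)
if err != nil {
	return trace.Wrap(err)
}
defer conn.Close() // also releases the matching agent channel
keys, err := agent.NewClient(conn).List()
if err != nil {
	return trace.Wrap(err)
}
log.Debugf("agent served %d keys", len(keys))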
In the case of proxy mode subsystem and agent @@ -877,7 +874,7 @@ func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChann rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleSessionRequests(ccx, identityContext, ch, requests) + go s.handleSessionRequests(ctx, ccx, identityContext, ch, requests) return default: rejectChannel(nch, ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)) @@ -895,7 +892,7 @@ func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChann rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleSessionRequests(ccx, identityContext, ch, requests) + go s.handleSessionRequests(ctx, ccx, identityContext, ch, requests) // Channels of type "direct-tcpip" handles request for port forwarding. case teleport.ChanDirectTCPIP: req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData()) @@ -910,43 +907,43 @@ func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChann rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleDirectTCPIPRequest(ccx, identityContext, ch, req) + go s.handleDirectTCPIPRequest(ctx, ccx, identityContext, ch, req) default: rejectChannel(nch, ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)) } } // handleDirectTCPIPRequest handles port forwarding requests. -func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, channel ssh.Channel, req *sshutils.DirectTCPIPReq) { +func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, channel ssh.Channel, req *sshutils.DirectTCPIPReq) { // Create context for this channel. This context will be closed when // forwarding is complete. - ctx, err := srv.NewServerContext(ccx, s, identityContext) + ctx, scx, err := srv.NewServerContext(ctx, ccx, s, identityContext) if err != nil { log.Errorf("Unable to create connection context: %v.", err) writeStderr(channel, "Unable to create connection context.") return } - ctx.IsTestStub = s.isTestStub - ctx.AddCloser(channel) - ctx.ChannelType = teleport.ChanDirectTCPIP - ctx.SrcAddr = net.JoinHostPort(req.Orig, strconv.Itoa(int(req.OrigPort))) - ctx.DstAddr = net.JoinHostPort(req.Host, strconv.Itoa(int(req.Port))) - defer ctx.Close() + scx.IsTestStub = s.isTestStub + scx.AddCloser(channel) + scx.ChannelType = teleport.ChanDirectTCPIP + scx.SrcAddr = net.JoinHostPort(req.Orig, strconv.Itoa(int(req.OrigPort))) + scx.DstAddr = net.JoinHostPort(req.Host, strconv.Itoa(int(req.Port))) + defer scx.Close() // Check if the role allows port forwarding for this user. - err = s.authHandlers.CheckPortForward(ctx.DstAddr, ctx) + err = s.authHandlers.CheckPortForward(scx.DstAddr, scx) if err != nil { writeStderr(channel, err.Error()) return } - ctx.Debugf("Opening direct-tcpip channel from %v to %v.", ctx.SrcAddr, ctx.DstAddr) - defer ctx.Debugf("Closing direct-tcpip channel from %v to %v.", ctx.SrcAddr, ctx.DstAddr) + scx.Debugf("Opening direct-tcpip channel from %v to %v.", scx.SrcAddr, scx.DstAddr) + defer scx.Debugf("Closing direct-tcpip channel from %v to %v.", scx.SrcAddr, scx.DstAddr) // Create command to re-exec Teleport which will perform a net.Dial. The // reason it's not done directly is because the PAM stack needs to be called // from another process. 
- cmd, err := srv.ConfigureCommand(ctx) + cmd, err := srv.ConfigureCommand(scx) if err != nil { writeStderr(channel, err.Error()) } @@ -995,6 +992,8 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident if err != nil && err != io.EOF { log.Warnf("Connection problem in \"direct-tcpip\" channel: %v %T.", trace.DebugReport(err), err) } + case <-ctx.Done(): + break case <-s.ctx.Done(): break } @@ -1007,36 +1006,31 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident // Emit a port forwarding event. s.EmitAuditEvent(events.PortForward, events.EventFields{ - events.PortForwardAddr: ctx.DstAddr, + events.PortForwardAddr: scx.DstAddr, events.PortForwardSuccess: true, - events.EventLogin: ctx.Identity.Login, - events.EventUser: ctx.Identity.TeleportUser, - events.LocalAddr: ctx.Conn.LocalAddr().String(), - events.RemoteAddr: ctx.Conn.RemoteAddr().String(), + events.EventLogin: scx.Identity.Login, + events.EventUser: scx.Identity.TeleportUser, + events.LocalAddr: scx.ServerConn.LocalAddr().String(), + events.RemoteAddr: scx.ServerConn.RemoteAddr().String(), }) } // handleSessionRequests handles out of band session requests once the session // channel has been created this function's loop handles all the "exec", // "subsystem" and "shell" requests. -func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) { +func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) { // Create context for this channel. This context will be closed when the // session request is complete. - ctx, err := srv.NewServerContext(ccx, s, identityContext) + ctx, scx, err := srv.NewServerContext(ctx, ccx, s, identityContext) if err != nil { log.Errorf("Unable to create connection context: %v.", err) writeStderr(ch, "Unable to create connection context.") return } - ctx.IsTestStub = s.isTestStub - ctx.AddCloser(ch) - ctx.ChannelType = teleport.ChanSession - defer ctx.Close() - - // Create a close context used to signal between the server and the - // keep-alive loop when to close the connection (from either side). - closeContext, closeCancel := context.WithCancel(context.Background()) - defer closeCancel() + scx.IsTestStub = s.isTestStub + scx.AddCloser(ch) + scx.ChannelType = teleport.ChanSession + defer scx.Close() clusterConfig, err := s.GetAccessPoint().GetClusterConfig() if err != nil { @@ -1050,44 +1044,44 @@ func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, identity // closeContext which signals the server to shutdown. 
go srv.StartKeepAliveLoop(srv.KeepAliveParams{
 		Conns: []srv.RequestSender{
-			ctx.Conn,
+			scx.ServerConn,
 		},
 		Interval:     clusterConfig.GetKeepAliveInterval(),
 		MaxCount:     clusterConfig.GetKeepAliveCountMax(),
-		CloseContext: closeContext,
-		CloseCancel:  closeCancel,
+		CloseContext: ctx,
+		CloseCancel:  scx.CancelFunc(),
 	})
 
 	for {
-		// update ctx with the session ID:
+		// update scx with the session ID:
 		if !s.proxyMode {
-			err := ctx.CreateOrJoinSession(s.reg)
+			err := scx.CreateOrJoinSession(s.reg)
 			if err != nil {
 				errorMessage := fmt.Sprintf("unable to update context: %v", err)
-				ctx.Errorf("Unable to update context: %v.", errorMessage)
+				scx.Errorf("Unable to update context: %v.", errorMessage)
 
 				// write the error to channel and close it
 				writeStderr(ch, errorMessage)
 				_, err := ch.SendRequest("exit-status", false, ssh.Marshal(struct{ C uint32 }{C: teleport.RemoteCommandFailure}))
 				if err != nil {
-					ctx.Errorf("Failed to send exit status %v.", errorMessage)
+					scx.Errorf("Failed to send exit status %v.", errorMessage)
 				}
 				return
 			}
 		}
 
 		select {
-		case creq := <-ctx.SubsystemResultCh:
+		case creq := <-scx.SubsystemResultCh:
 			// this means that the subsystem has finished executing and
 			// wants us to close the session and the channel
-			ctx.Debugf("Close session request: %v.", creq.Err)
+			scx.Debugf("Close session request: %v.", creq.Err)
 			return
 		case req := <-in:
 			if req == nil {
 				// this will happen when the client closes/drops the connection
-				ctx.Debugf("Client %v disconnected.", ctx.Conn.RemoteAddr())
+				scx.Debugf("Client %v disconnected.", scx.ServerConn.RemoteAddr())
 				return
 			}
-			if err := s.dispatch(ch, req, ctx); err != nil {
+			if err := s.dispatch(ch, req, scx); err != nil {
 				s.replyError(ch, req, err)
 				return
 			}
@@ -1096,19 +1090,19 @@ func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, identity
 					log.Warnf("Failed to reply to %q request: %v", req.Type, err)
 				}
 			}
-		case result := <-ctx.ExecResultCh:
-			ctx.Debugf("Exec request (%q) complete: %v", result.Command, result.Code)
+		case result := <-scx.ExecResultCh:
+			scx.Debugf("Exec request (%q) complete: %v", result.Command, result.Code)
 
 			// The exec process has finished and delivered the execution result; send
 			// the result back to the client, and close the session and channel.
 			_, err := ch.SendRequest("exit-status", false, ssh.Marshal(struct{ C uint32 }{C: uint32(result.Code)}))
 			if err != nil {
-				ctx.Infof("Failed to send exit status for %v: %v", result.Command, err)
+				scx.Infof("Failed to send exit status for %v: %v", result.Command, err)
 			}
 
 			return
-		case <-closeContext.Done():
-			log.Debugf("Closing session due to missed heartbeat.")
+		case <-ctx.Done():
+			log.Debugf("Closing session due to cancellation.")
 			return
 		}
 	}
@@ -1188,14 +1182,9 @@ func (s *Server) handleAgentForwardNode(req *ssh.Request, ctx *srv.ServerContext
 		return trace.Wrap(err)
 	}
 
-	// open a channel to the client where the client will serve an agent
-	authChannel, _, err := ctx.Conn.OpenChannel(sshutils.AuthAgentRequest, nil)
-	if err != nil {
-		return trace.Wrap(err)
-	}
-
-	// save the agent in the context so it can be used later
-	ctx.Parent.SetAgent(agent.NewClient(authChannel), authChannel)
+	// Enable agent forwarding for the broader connection-level
+	// context.
+	ctx.Parent().SetForwardAgent(true)
 
 	// serve an agent on a unix socket on this node
 	err = s.serveAgent(ctx)
@@ -1224,16 +1213,9 @@ func (s *Server) handleAgentForwardProxy(req *ssh.Request, ctx *srv.ServerContex
 		return trace.Wrap(err)
 	}
 
-	// Open a channel to the client where the client will serve an agent.
-	authChannel, _, err := ctx.Conn.OpenChannel(sshutils.AuthAgentRequest, nil)
-	if err != nil {
-		return trace.Wrap(err)
-	}
-
-	// Save the agent so it can be used when making a proxy subsystem request
-	// later. It will also be used when building a remote connection to the
-	// target node.
-	ctx.Parent.SetAgent(agent.NewClient(authChannel), authChannel)
+	// Enable agent forwarding for the broader connection-level
+	// context.
+	ctx.Parent().SetForwardAgent(true)
 
 	return nil
 }
@@ -1247,7 +1229,7 @@ func (s *Server) handleSubsystem(ch ssh.Channel, req *ssh.Request, ctx *srv.Serv
 	ctx.Debugf("Subsystem request: %v.", sb)
 	// starting the subsystem blocks the client,
 	// while collecting its result and waiting does not
-	if err := sb.Start(ctx.Conn, ch, req, ctx); err != nil {
+	if err := sb.Start(ctx.ServerConn, ch, req, ctx); err != nil {
 		ctx.Warnf("Subsystem request %v failed: %v.", sb, err)
 		ctx.SendSubsystemResult(srv.SubsystemResult{Err: trace.Wrap(err)})
 		return trace.Wrap(err)
@@ -1328,18 +1310,18 @@ func (s *Server) handleVersionRequest(req *ssh.Request) {
 }
 
 // handleProxyJump handles a ProxyJump request that is executed via direct tcp-ip dial on the proxy
-func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, ch ssh.Channel, req sshutils.DirectTCPIPReq) {
+func (s *Server) handleProxyJump(ctx context.Context, ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, ch ssh.Channel, req sshutils.DirectTCPIPReq) {
 	// Create context for this channel. This context will be closed when the
 	// session request is complete.
-	ctx, err := srv.NewServerContext(ccx, s, identityContext)
+	ctx, scx, err := srv.NewServerContext(ctx, ccx, s, identityContext)
 	if err != nil {
 		log.Errorf("Unable to create connection context: %v.", err)
 		writeStderr(ch, "Unable to create connection context.")
 		return
 	}
-	ctx.IsTestStub = s.isTestStub
-	ctx.AddCloser(ch)
-	defer ctx.Close()
+	scx.IsTestStub = s.isTestStub
+	scx.AddCloser(ch)
+	defer scx.Close()
 
 	clusterConfig, err := s.GetAccessPoint().GetClusterConfig()
 	if err != nil {
@@ -1374,7 +1356,7 @@ func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContex
 	// which is a hack, but the only way we can think of making it work,
 	// ideas are appreciated.
 	if clusterConfig.GetSessionRecording() == services.RecordAtProxy {
-		err = s.handleAgentForwardProxy(&ssh.Request{}, ctx)
+		err = s.handleAgentForwardProxy(&ssh.Request{}, scx)
 		if err != nil {
 			log.Warningf("Failed to request agent in recording mode: %v", err)
 			writeStderr(ch, "Failed to request agent")
@@ -1387,26 +1369,17 @@ func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContex
 	// closeContext which signals the server to shutdown.
 	go srv.StartKeepAliveLoop(srv.KeepAliveParams{
 		Conns: []srv.RequestSender{
-			ctx.Conn,
+			scx.ServerConn,
 		},
 		Interval:     clusterConfig.GetKeepAliveInterval(),
 		MaxCount:     clusterConfig.GetKeepAliveCountMax(),
-		CloseContext: ctx.CancelContext(),
-		// Looks liks this is this the best way to signal
-		// close to the proxy subsystem, as it will close
-		// the channel that proxy subsystem is blocked on.
-		CloseCancel: func() {
-			if err := ctx.Close(); err != nil {
-				log.Warningf("Failed to close: %v.", err)
-			}
-		},
+		CloseContext: ctx,
+		CloseCancel:  scx.CancelFunc(),
 	})
 
-	subsys, err := newProxySubsys(proxySubsysConfig{
+	subsys, err := newProxySubsys(scx, s, proxySubsysRequest{
 		host: req.Host,
 		port: fmt.Sprintf("%v", req.Port),
-		srv:  s,
-		ctx:  ctx,
 	})
 	if err != nil {
 		log.Errorf("Unable to instantiate proxy subsystem: %v.", err)
@@ -1414,16 +1387,23 @@ func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContex
 		return
 	}
 
-	if err := subsys.Start(ctx.Conn, ch, &ssh.Request{}, ctx); err != nil {
+	if err := subsys.Start(scx.ServerConn, ch, &ssh.Request{}, scx); err != nil {
 		log.Errorf("Unable to start proxy subsystem: %v.", err)
 		writeStderr(ch, "Unable to start proxy subsystem.")
 		return
 	}
 
-	if err := subsys.Wait(); err != nil {
-		log.Errorf("Proxy subsystem failed: %v.", err)
-		writeStderr(ch, "Proxy subsystem failed.")
-		return
+	wch := make(chan struct{})
+	go func() {
+		defer close(wch)
+		if err := subsys.Wait(); err != nil {
+			log.Errorf("Proxy subsystem failed: %v.", err)
+			writeStderr(ch, "Proxy subsystem failed.")
+		}
+	}()
+	select {
+	case <-wch:
+	case <-ctx.Done():
 	}
 }
 
diff --git a/lib/srv/sess.go b/lib/srv/sess.go
index d8d3827ab8a75..7dcf0748e84bf 100644
--- a/lib/srv/sess.go
+++ b/lib/srv/sess.go
@@ -128,12 +128,12 @@ func (s *SessionRegistry) emitSessionJoinEvent(ctx *ServerContext) {
 		events.EventNamespace:  s.srv.GetNamespace(),
 		events.EventLogin:      ctx.Identity.Login,
 		events.EventUser:       ctx.Identity.TeleportUser,
-		events.RemoteAddr:      ctx.Conn.RemoteAddr().String(),
+		events.RemoteAddr:      ctx.ServerConn.RemoteAddr().String(),
 		events.SessionServerID: ctx.srv.HostUUID(),
 	}
 	// Local address only makes sense for non-tunnel nodes.
 	if !ctx.srv.UseTunnel() {
-		sessionJoinEvent[events.LocalAddr] = ctx.Conn.LocalAddr().String()
+		sessionJoinEvent[events.LocalAddr] = ctx.ServerConn.LocalAddr().String()
 	}
 
 	// Emit session join event to Audit Log.
@@ -672,14 +672,14 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
 		events.SessionServerID:       ctx.srv.HostUUID(),
 		events.EventLogin:            ctx.Identity.Login,
 		events.EventUser:             ctx.Identity.TeleportUser,
-		events.RemoteAddr:            ctx.Conn.RemoteAddr().String(),
+		events.RemoteAddr:            ctx.ServerConn.RemoteAddr().String(),
 		events.TerminalSize:          params.Serialize(),
 		events.SessionServerHostname: ctx.srv.GetInfo().GetHostname(),
 		events.SessionServerLabels:   ctx.srv.GetInfo().GetAllLabels(),
 	}
 	// Local address only makes sense for non-tunnel nodes.
 	if !ctx.srv.UseTunnel() {
-		eventFields[events.LocalAddr] = ctx.Conn.LocalAddr().String()
+		eventFields[events.LocalAddr] = ctx.ServerConn.LocalAddr().String()
 	}
 	s.emitAuditEvent(events.SessionStart, eventFields)
 
@@ -787,13 +787,13 @@ func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {
 		events.SessionServerID:       ctx.srv.HostUUID(),
 		events.EventLogin:            ctx.Identity.Login,
 		events.EventUser:             ctx.Identity.TeleportUser,
-		events.RemoteAddr:            ctx.Conn.RemoteAddr().String(),
+		events.RemoteAddr:            ctx.ServerConn.RemoteAddr().String(),
 		events.SessionServerHostname: ctx.srv.GetInfo().GetHostname(),
 		events.SessionServerLabels:   ctx.srv.GetInfo().GetAllLabels(),
 	}
 	// Local address only makes sense for non-tunnel nodes.
if !ctx.srv.UseTunnel() {
-		eventFields[events.LocalAddr] = ctx.Conn.LocalAddr().String()
+		eventFields[events.LocalAddr] = ctx.ServerConn.LocalAddr().String()
 	}
 	s.emitAuditEvent(events.SessionStart, eventFields)
 
@@ -1188,12 +1188,12 @@ func newParty(s *session, ch ssh.Channel, ctx *ServerContext) *party {
 		user:      ctx.Identity.TeleportUser,
 		login:     ctx.Identity.Login,
 		serverID:  s.registry.srv.ID(),
-		site:      ctx.Conn.RemoteAddr().String(),
+		site:      ctx.ServerConn.RemoteAddr().String(),
 		id:        rsession.NewID(),
 		ch:        ch,
 		ctx:       ctx,
 		s:         s,
-		sconn:     ctx.Conn,
+		sconn:     ctx.ServerConn,
 		termSizeC: make(chan []byte, 5),
 		closeC:    make(chan bool),
 	}
diff --git a/lib/sshutils/ctx.go b/lib/sshutils/ctx.go
index 1578aac0d8989..5764b17f5b1a7 100644
--- a/lib/sshutils/ctx.go
+++ b/lib/sshutils/ctx.go
@@ -17,10 +17,13 @@ limitations under the License.
 package sshutils
 
 import (
+	"context"
 	"io"
 	"net"
 	"sync"
 
+	"github.com/gravitational/teleport/lib/teleagent"
+
 	"golang.org/x/crypto/ssh"
 	"golang.org/x/crypto/ssh/agent"
 
@@ -42,23 +45,71 @@ type ConnectionContext struct {
 	// set for all channels.
 	env map[string]string
 
-	// agent is a client to remote SSH agent.
-	agent agent.Agent
+	// forwardAgent indicates that agent forwarding has
+	// been requested for this connection.
+	forwardAgent bool
 
-	// agentCh is SSH channel using SSH agent protocol.
-	agentChannel ssh.Channel
 	// closers is a list of io.Closers that will be called when the session closes;
 	// this is handy because when a client closes the session, its resources are
 	// properly closed and deallocated rather than left hanging.
 	closers []io.Closer
+
+	// closed indicates that closers have been run.
+	closed bool
+
+	// cancel cancels the context.Context scope associated with this ConnectionContext.
+	cancel context.CancelFunc
 }
 
-// NewConnectionContext creates a new ConnectionContext instance.
-func NewConnectionContext(nconn net.Conn, sconn *ssh.ServerConn) *ConnectionContext {
-	return &ConnectionContext{
+// NewConnectionContext creates a new ConnectionContext and a child context.Context
+// instance which will be canceled when the ConnectionContext is closed.
+func NewConnectionContext(ctx context.Context, nconn net.Conn, sconn *ssh.ServerConn) (context.Context, *ConnectionContext) {
+	ctx, cancel := context.WithCancel(ctx)
+	return ctx, &ConnectionContext{
 		NetConn:    nconn,
 		ServerConn: sconn,
 		env:        make(map[string]string),
+		cancel:     cancel,
+	}
+}
+
+// agentChannel implements the extended teleagent.Agent interface,
+// allowing the underlying ssh.Channel to be closed when the agent
+// is no longer needed.
+type agentChannel struct {
+	agent.Agent
+	ch ssh.Channel
+}
+
+func (a *agentChannel) Close() error {
+	return a.ch.Close()
+}
+
+// StartAgentChannel sets up a new agent forwarding channel against this connection.
+// The channel is automatically closed when either the ConnectionContext or the
+// supplied context.Context is canceled.
+func (c *ConnectionContext) StartAgentChannel() (teleagent.Agent, error) {
+	// refuse to start an agent if forwardAgent has not yet been set.
+	if !c.GetForwardAgent() {
+		return nil, trace.AccessDenied("agent forwarding not requested or not authorized")
+	}
+	// open an agent channel to the client
+	ch, _, err := c.ServerConn.OpenChannel(AuthAgentRequest, nil)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+	return &agentChannel{
+		Agent: agent.NewClient(ch),
+		ch:    ch,
+	}, nil
+}
+
+// VisitEnv grants visitor-style access to env variables.
+func (c *ConnectionContext) VisitEnv(visit func(key, val string)) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	for key, val := range c.env {
+		visit(key, val)
 	}
 }
 
@@ -77,41 +128,19 @@ func (c *ConnectionContext) GetEnv(key string) (string, bool) {
 	return val, ok
 }
 
-// ExportEnv writes all env vars to supplied map (used to configure
-// env of child contexts).
-func (c *ConnectionContext) ExportEnv(m map[string]string) {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-	for key, val := range c.env {
-		m[key] = val
-	}
-}
-
-// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent,
-// or nil if no agent is available in this context.
-func (c *ConnectionContext) GetAgent() agent.Agent {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-	return c.agent
-}
-
-// GetAgentChannel returns the channel over which communication with the agent occurs,
-// or nil if no agent is available in this context.
-func (c *ConnectionContext) GetAgentChannel() ssh.Channel {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-	return c.agentChannel
+// SetForwardAgent configures this context to support agent forwarding.
+// It must not be called until agent forwarding has been explicitly requested.
+func (c *ConnectionContext) SetForwardAgent(forwardAgent bool) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.forwardAgent = forwardAgent
 }
 
-// SetAgent sets the agent and channel over which communication with the agent occurs.
-func (c *ConnectionContext) SetAgent(a agent.Agent, channel ssh.Channel) {
+// GetForwardAgent returns the forwardAgent flag under lock.
+func (c *ConnectionContext) GetForwardAgent() bool {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	if c.agentChannel != nil {
-		c.agentChannel.Close()
-	}
-	c.agentChannel = channel
-	c.agent = a
+	return c.forwardAgent
 }
 
 // AddCloser adds any closer in ctx that will be called
@@ -119,6 +148,12 @@ func (c *ConnectionContext) SetAgent(a agent.Agent, channel ssh.Channel)
 func (c *ConnectionContext) AddCloser(closer io.Closer) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
+	// if context was already closed, run the closer immediately
+	// in the background.
+	if c.closed {
+		go closer.Close()
+		return
+	}
 	c.closers = append(c.closers, closer)
 }
 
@@ -131,10 +166,8 @@ func (c *ConnectionContext) SetAgent(a agent.Agent, channel ssh.Channel)
 func (c *ConnectionContext) takeClosers() []io.Closer {
 	closers := c.closers
 	c.closers = nil
 
-	if c.agentChannel != nil {
-		closers = append(closers, c.agentChannel)
-		c.agentChannel = nil
-	}
+	c.closed = true
+
 	return closers
 }
 
@@ -142,6 +175,8 @@ func (c *ConnectionContext) takeClosers() []io.Closer {
 func (c *ConnectionContext) Close() error {
 	var errs []error
 
+	c.cancel()
+
 	closers := c.takeClosers()
 
 	for _, cl := range closers {
diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go
index a46a1d022b2dd..ae446bbb658cd 100644
--- a/lib/sshutils/server.go
+++ b/lib/sshutils/server.go
@@ -437,7 +437,11 @@ func (s *Server) HandleConnection(conn net.Conn) {
 	defer keepAliveTick.Stop()
 	keepAlivePayload := [8]byte{0}
 
-	ccx := NewConnectionContext(wconn, sconn)
+	// NOTE: we deliberately don't use s.closeContext here because the server's
+	// closeContext is used to halt the acceptance of new connections on shutdown;
+	// it is not intended to interrupt in-progress connection handling, and is
+	// therefore orthogonal to the role of ConnectionContext.
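To make the NOTE above concrete, here is a hedged sketch of the intended separation of concerns, with simplified names (serverCtx plays the role of s.closeContext, and the listener handling is schematic rather than Teleport's actual accept path):

package lifecycle

import (
	"context"
	"net"
)

// acceptLoop stops taking new connections once serverCtx is done, but
// derives each connection's scope from context.Background() rather than
// serverCtx, so that server shutdown does not rip out in-flight
// connection handling.
func acceptLoop(serverCtx context.Context, ln net.Listener, handle func(context.Context, net.Conn)) {
	go func() {
		<-serverCtx.Done()
		ln.Close() // unblock Accept; no new connections are admitted
	}()
	for {
		conn, err := ln.Accept()
		if err != nil {
			return // listener closed: stop accepting, existing conns live on
		}
		connCtx, cancel := context.WithCancel(context.Background())
		go func() {
			defer cancel()
			defer conn.Close()
			handle(connCtx, conn)
		}()
	}
}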
+	ctx, ccx := NewConnectionContext(context.Background(), wconn, sconn)
 	defer ccx.Close()
 
 	for {
@@ -458,7 +462,7 @@ func (s *Server) HandleConnection(conn net.Conn) {
 				connClosed()
 				return
 			}
-			go s.newChanHandler.HandleNewChan(ccx, nch)
+			go s.newChanHandler.HandleNewChan(ctx, ccx, nch)
 		// send keepalive pings to the clients
 		case <-keepAliveTick.C:
 			const wantReply = true
@@ -466,6 +470,8 @@ func (s *Server) HandleConnection(conn net.Conn) {
 			if err != nil {
 				log.Errorf("Failed sending keepalive request: %v", err)
 			}
+		case <-ctx.Done():
+			log.Debugf("Connection context canceled: %v -> %v", conn.RemoteAddr(), conn.LocalAddr())
 		}
 	}
 }
@@ -475,13 +481,13 @@ type RequestHandler interface {
 }
 
 type NewChanHandler interface {
-	HandleNewChan(*ConnectionContext, ssh.NewChannel)
+	HandleNewChan(context.Context, *ConnectionContext, ssh.NewChannel)
 }
 
-type NewChanHandlerFunc func(*ConnectionContext, ssh.NewChannel)
+type NewChanHandlerFunc func(context.Context, *ConnectionContext, ssh.NewChannel)
 
-func (f NewChanHandlerFunc) HandleNewChan(ccx *ConnectionContext, ch ssh.NewChannel) {
-	f(ccx, ch)
+func (f NewChanHandlerFunc) HandleNewChan(ctx context.Context, ccx *ConnectionContext, ch ssh.NewChannel) {
+	f(ctx, ccx, ch)
 }
 
 type AuthMethods struct {
diff --git a/lib/sshutils/server_test.go b/lib/sshutils/server_test.go
index dc2c83fa57afd..cb59cb66a1739 100644
--- a/lib/sshutils/server_test.go
+++ b/lib/sshutils/server_test.go
@@ -51,7 +51,7 @@ func (s *ServerSuite) SetUpSuite(c *check.C) {
 
 func (s *ServerSuite) TestStartStop(c *check.C) {
 	called := false
-	fn := NewChanHandlerFunc(func(_ *ConnectionContext, nch ssh.NewChannel) {
+	fn := NewChanHandlerFunc(func(_ context.Context, _ *ConnectionContext, nch ssh.NewChannel) {
 		called = true
 		err := nch.Reject(ssh.Prohibited, "nothing to see here")
 		c.Assert(err, check.IsNil)
@@ -88,7 +88,7 @@ func (s *ServerSuite) TestStartStop(c *check.C) {
 
 // TestShutdown tests graceful shutdown feature
 func (s *ServerSuite) TestShutdown(c *check.C) {
 	closeContext, cancel := context.WithCancel(context.TODO())
-	fn := NewChanHandlerFunc(func(ccx *ConnectionContext, nch ssh.NewChannel) {
+	fn := NewChanHandlerFunc(func(_ context.Context, ccx *ConnectionContext, nch ssh.NewChannel) {
 		ch, _, err := nch.Accept()
 		c.Assert(err, check.IsNil)
 		defer ch.Close()
@@ -138,7 +138,7 @@ func (s *ServerSuite) TestShutdown(c *check.C) {
 }
 
 func (s *ServerSuite) TestConfigureCiphers(c *check.C) {
-	fn := NewChanHandlerFunc(func(_ *ConnectionContext, nch ssh.NewChannel) {
+	fn := NewChanHandlerFunc(func(_ context.Context, _ *ConnectionContext, nch ssh.NewChannel) {
 		err := nch.Reject(ssh.Prohibited, "nothing to see here")
 		c.Assert(err, check.IsNil)
 	})
@@ -185,7 +185,7 @@ func (s *ServerSuite) TestHostSignerFIPS(c *check.C) {
 	_, ellipticSigner, err := utils.CreateEllipticCertificate("foo", ssh.HostCert)
 	c.Assert(err, check.IsNil)
 
-	newChanHandler := NewChanHandlerFunc(func(_ *ConnectionContext, nch ssh.NewChannel) {
+	newChanHandler := NewChanHandlerFunc(func(_ context.Context, _ *ConnectionContext, nch ssh.NewChannel) {
 		err := nch.Reject(ssh.Prohibited, "nothing to see here")
 		c.Assert(err, check.IsNil)
 	})
diff --git a/lib/teleagent/agent.go b/lib/teleagent/agent.go
index 7588c3718e8ff..9f04171afe5b7 100644
--- a/lib/teleagent/agent.go
+++ b/lib/teleagent/agent.go
@@ -14,16 +14,42 @@ import (
 	"golang.org/x/crypto/ssh/agent"
 )
 
+// Agent extends the agent.Agent interface.
+// APIs which accept this interface promise to
+// call `Close()` when they are done using the
+// supplied agent.
+type Agent interface {
+	agent.Agent
+	io.Closer
+}
+
+// nopCloser wraps an agent.Agent in the extended
+// Agent interface by adding a NOP closer.
+type nopCloser struct {
+	agent.Agent
+}
+
+func (n nopCloser) Close() error { return nil }
+
+// NopCloser wraps an agent.Agent with a NOP closer, allowing it
+// to be passed to APIs which expect the extended agent interface.
+func NopCloser(std agent.Agent) Agent {
+	return nopCloser{std}
+}
+
+// Getter is a function used to get an agent instance.
+type Getter func() (Agent, error)
+
 // AgentServer is an implementation of an SSH agent server
 type AgentServer struct {
-	agent.Agent
+	getAgent Getter
 	listener net.Listener
 	path     string
 }
 
 // NewServer returns a new instance of the agent server
-func NewServer() *AgentServer {
-	return &AgentServer{Agent: agent.NewKeyring()}
+func NewServer(getter Getter) *AgentServer {
+	return &AgentServer{getAgent: getter}
 }
 
 // ListenUnixSocket starts listening and serving agent assuming that
@@ -77,10 +103,21 @@ func (a *AgentServer) Serve() error {
 			continue
 		}
 		tempDelay = 0
+
+		// get an agent instance for serving this conn
+		instance, err := a.getAgent()
+		if err != nil {
+			log.Errorf("Failed to get agent: %v", err)
+			return trace.Wrap(err)
+		}
+
+		// serve agent protocol against conn in a
+		// separate goroutine.
 		go func() {
-			if err := agent.ServeAgent(a.Agent, conn); err != nil {
+			defer instance.Close()
+			if err := agent.ServeAgent(instance, conn); err != nil {
 				if err != io.EOF {
-					log.Errorf(err.Error())
+					log.Error(err)
 				}
 			}
 		}()
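Taken together with the ConnectionContext changes above, the intended composition looks roughly like the helper below. It is illustrative only and not part of the patch: a per-connection Getter is wired to the lazy agent channel, so a new AuthAgentRequest channel to the client is opened only when an agent client actually connects to the socket, and the channel's Close runs as soon as that client has been served.

package example

import (
	"github.com/gravitational/teleport/lib/sshutils"
	"github.com/gravitational/teleport/lib/teleagent"
)

// serveLazyAgent starts an agent server whose backing agent channel is
// created on demand from the connection context, matching openssh's
// lazy agent forwarding behavior.
func serveLazyAgent(ccx *sshutils.ConnectionContext, path string, uid, gid int) (*teleagent.AgentServer, error) {
	agentServer := teleagent.NewServer(func() (teleagent.Agent, error) {
		// fails unless agent forwarding was requested and authorized,
		// since StartAgentChannel checks GetForwardAgent first.
		return ccx.StartAgentChannel()
	})
	if err := agentServer.ListenUnixSocket(path, uid, gid, 0600); err != nil {
		return nil, err
	}
	go agentServer.Serve()
	return agentServer, nil
}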