Skip to content

Commit

Permalink
fix agent forwarding for multi-session connections
Browse files Browse the repository at this point in the history
Changes the lifetime of agent forwarding to be scoped
to the underlying ssh connection, instead of the
specific ssh channel which initially passed the agent
forwarding request.
  • Loading branch information
fspmarshall committed Apr 29, 2020
1 parent bdd388e commit c341d2b
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 85 deletions.
6 changes: 3 additions & 3 deletions integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1265,11 +1265,11 @@ func (s *discardServer) Stop() {
s.sshServer.Close()
}

func (s *discardServer) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, newChannel ssh.NewChannel) {
func (s *discardServer) HandleNewChan(ccx *sshutils.ConnectionContext, newChannel ssh.NewChannel) {
channel, reqs, err := newChannel.Accept()
if err != nil {
sconn.Close()
conn.Close()
ccx.ServerConn.Close()
ccx.NetConn.Close()
return
}

Expand Down
4 changes: 2 additions & 2 deletions lib/multiplexer/multiplexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (s *MuxSuite) TestMultiplexing(c *check.C) {
defer backend1.Close()

called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) {
called = true
nch.Reject(ssh.Prohibited, "nothing to see here")
})
Expand Down Expand Up @@ -381,7 +381,7 @@ func (s *MuxSuite) TestDisableTLS(c *check.C) {
defer backend1.Close()

called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) {
called = true
nch.Reject(ssh.Prohibited, "nothing to see here")
})
Expand Down
5 changes: 3 additions & 2 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,11 +527,12 @@ func (s *server) Shutdown(ctx context.Context) error {
return s.srv.Shutdown(ctx)
}

func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
func (s *server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChannel) {
// Apply read/write timeouts to the server connection.
conn = utils.ObeyIdleTimeout(conn,
conn := utils.ObeyIdleTimeout(ccx.NetConn,
s.offlineThreshold,
"reverse tunnel server")
sconn := ccx.ServerConn

channelType := nch.ChannelType()
switch channelType {
Expand Down
65 changes: 30 additions & 35 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ type ServerContext struct {

sync.RWMutex

Parent *sshutils.ConnectionContext

// env is a list of environment variables passed to the session.
env map[string]string

Expand All @@ -186,12 +188,6 @@ type ServerContext struct {
// term holds PTY if it was requested by the session.
term Terminal

// agent is a client to remote SSH agent.
agent agent.Agent

// agentCh is SSH channel using SSH agent protocol.
agentChannel ssh.Channel

// session holds the active session (if there's an active one).
session *session

Expand Down Expand Up @@ -291,7 +287,7 @@ type ServerContext struct {

// NewServerContext creates a new *ServerContext which is used to pass and
// manage resources.
func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext IdentityContext) (*ServerContext, error) {
func NewServerContext(ccx *sshutils.ConnectionContext, srv Server, identityContext IdentityContext) (*ServerContext, error) {
clusterConfig, err := srv.GetAccessPoint().GetClusterConfig()
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -300,13 +296,15 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
cancelContext, cancel := context.WithCancel(context.TODO())

ctx := &ServerContext{
Parent: ccx,
id: int(atomic.AddInt32(&ctxID, int32(1))),
env: make(map[string]string),
srv: srv,
Conn: conn,
Connection: ccx.NetConn,
Conn: ccx.ServerConn,
ExecResultCh: make(chan ExecResult, 10),
SubsystemResultCh: make(chan SubsystemResult, 10),
ClusterName: conn.Permissions.Extensions[utils.CertTeleportClusterName],
ClusterName: ccx.ServerConn.Permissions.Extensions[utils.CertTeleportClusterName],
ClusterConfig: clusterConfig,
Identity: identityContext,
clientIdleTimeout: identityContext.RoleSet.AdjustClientIdleTimeout(clusterConfig.GetClientIdleTimeout()),
Expand All @@ -320,8 +318,8 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
}

fields := log.Fields{
"local": conn.LocalAddr(),
"remote": conn.RemoteAddr(),
"local": ctx.Conn.LocalAddr(),
"remote": ctx.Conn.RemoteAddr(),
"login": ctx.Identity.Login,
"teleportUser": ctx.Identity.TeleportUser,
"id": ctx.id,
Expand All @@ -343,7 +341,7 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
ClientIdleTimeout: ctx.clientIdleTimeout,
Clock: ctx.srv.GetClock(),
Tracker: ctx,
Conn: conn,
Conn: ctx.Conn,
Context: cancelContext,
TeleportUser: ctx.Identity.TeleportUser,
Login: ctx.Identity.Login,
Expand Down Expand Up @@ -374,6 +372,9 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
ctx.AddCloser(ctx.contr)
ctx.AddCloser(ctx.contw)

// gather environment variables from parent.
ctx.ImportParentEnv()

return ctx, nil
}

Expand Down Expand Up @@ -447,30 +448,22 @@ func (c *ServerContext) AddCloser(closer io.Closer) {
c.closers = append(c.closers, closer)
}

// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent.
// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent,
// or nil if no agent is available in this context.
func (c *ServerContext) GetAgent() agent.Agent {
c.RLock()
defer c.RUnlock()
return c.agent
if c.Parent == nil {
return nil
}
return c.Parent.GetAgent()
}

// GetAgentChannel returns the channel over which communication with the agent occurs.
// GetAgentChannel returns the channel over which communication with the agent occurs,
// or nil if no agent is available in this context.
func (c *ServerContext) GetAgentChannel() ssh.Channel {
c.RLock()
defer c.RUnlock()
return c.agentChannel
}

// SetAgent sets the agent and channel over which communication with the agent occurs.
func (c *ServerContext) SetAgent(a agent.Agent, channel ssh.Channel) {
c.Lock()
defer c.Unlock()
if c.agentChannel != nil {
c.Infof("closing previous agent channel")
c.agentChannel.Close()
if c.Parent == nil {
return nil
}
c.agentChannel = channel
c.agent = a
return c.Parent.GetAgentChannel()
}

// GetTerm returns a Terminal.
Expand Down Expand Up @@ -500,6 +493,12 @@ func (c *ServerContext) GetEnv(key string) (string, bool) {
return val, ok
}

// ImportParentEnv is used to re-synchronize env vars after
// parent context has been updated.
func (c *ServerContext) ImportParentEnv() {
c.Parent.ExportEnv(c.env)
}

// takeClosers returns all resources that should be closed and sets the properties to null
// we do this to avoid calling Close() under lock to avoid potential deadlocks
func (c *ServerContext) takeClosers() []io.Closer {
Expand All @@ -512,10 +511,6 @@ func (c *ServerContext) takeClosers() []io.Closer {
closers = append(closers, c.term)
c.term = nil
}
if c.agentChannel != nil {
closers = append(closers, c.agentChannel)
c.agentChannel = nil
}
closers = append(closers, c.closers...)
c.closers = nil
return closers
Expand Down
11 changes: 9 additions & 2 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ type Server struct {
// forwarding, subsystems.
remoteClient *ssh.Client

// connectionContext is used to construct ServerContext instances
// and supports registration of connection-scoped resource closers.
connectionContext *sshutils.ConnectionContext

// identityContext holds identity information about the user that has
// authenticated on sconn (like system login, Teleport username, roles).
identityContext srv.IdentityContext
Expand Down Expand Up @@ -435,6 +439,8 @@ func (s *Server) Serve() {
}
s.sconn = sconn

s.connectionContext = sshutils.NewConnectionContext(s.serverConn, s.sconn)

// Take connection and extract identity information for the user from it.
s.identityContext, err = s.authHandlers.CreateIdentityContext(sconn)
if err != nil {
Expand Down Expand Up @@ -488,6 +494,7 @@ func (s *Server) Close() error {
s.serverConn,
s.targetConn,
s.remoteClient,
s.connectionContext,
}

var errs []error
Expand Down Expand Up @@ -646,7 +653,7 @@ func (s *Server) handleChannel(nch ssh.NewChannel) {
func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext)
ctx, err := srv.NewServerContext(s.connectionContext, s, s.identityContext)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
Expand Down Expand Up @@ -713,7 +720,7 @@ func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) {
// There is no need for the forwarding server to initiate disconnects,
// based on teleport business logic, because this logic is already
// done on the server's terminating side.
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext)
ctx, err := srv.NewServerContext(s.connectionContext, s, s.identityContext)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
Expand Down
Loading

0 comments on commit c341d2b

Please # to comment.