From d694fd3428c1c342e1a66d98c51fa8b325062044 Mon Sep 17 00:00:00 2001 From: Joel Date: Tue, 15 Mar 2022 13:57:46 +0100 Subject: [PATCH] Fix quit on ctrlc, race panic, atomic load align in session IO (#11112) --- lib/srv/termmanager.go | 38 ++++++++++++++++------------ lib/srv/termmanager_test.go | 50 +++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 16 deletions(-) create mode 100644 lib/srv/termmanager_test.go diff --git a/lib/srv/termmanager.go b/lib/srv/termmanager.go index babf05dfa2a8a..b9bf958ca3c33 100644 --- a/lib/srv/termmanager.go +++ b/lib/srv/termmanager.go @@ -31,12 +31,15 @@ const maxHistory = 1000 // - history scrollback for new clients // - stream breaking type TermManager struct { + // These two fields need to be first in the struct so that they are 64-bit aligned which is a requirement + // for atomic operations on certain architectures. + countWritten uint64 + countRead uint64 + mu sync.Mutex writers map[string]io.Writer - readerState map[string]*int32 + readerState map[string]bool OnWriteError func(idString string, err error) - countWritten uint64 - countRead uint64 // buffer is used to buffer writes when turned off buffer []byte on bool @@ -48,7 +51,7 @@ type TermManager struct { // we only support one concurrent reader so this isn't mutex protected remaining []byte readStateUpdate *sync.Cond - closed *int32 + closed bool lastWasBroadcast bool terminateNotifier chan struct{} } @@ -57,8 +60,8 @@ type TermManager struct { func NewTermManager() *TermManager { return &TermManager{ writers: make(map[string]io.Writer), - readerState: make(map[string]*int32), - closed: new(int32), + readerState: make(map[string]bool), + closed: false, readStateUpdate: sync.NewCond(&sync.Mutex{}), incoming: make(chan []byte, 100), terminateNotifier: make(chan struct{}), @@ -214,8 +217,7 @@ func (g *TermManager) DeleteWriter(name string) { } func (g *TermManager) AddReader(name string, r io.Reader) { - readerState := new(int32) - g.readerState[name] = readerState + g.readerState[name] = false go func() { for { @@ -231,21 +233,24 @@ func (g *TermManager) AddReader(name string, r io.Reader) { // This is the ASCII control code for CTRL+C. if b == 0x03 { g.mu.Lock() - if !g.on { + if !g.on && !g.closed { select { case g.terminateNotifier <- struct{}{}: default: } } g.mu.Unlock() - return + break } } g.incoming <- buf[:n] - if atomic.LoadInt32(g.closed) == 1 || atomic.LoadInt32(readerState) == 1 { + g.mu.Lock() + if g.closed || g.readerState[name] { + g.mu.Unlock() return } + g.mu.Unlock() } }() } @@ -253,10 +258,7 @@ func (g *TermManager) AddReader(name string, r io.Reader) { func (g *TermManager) DeleteReader(name string) { g.mu.Lock() defer g.mu.Unlock() - - if g.readerState[name] != nil { - atomic.StoreInt32(g.readerState[name], 1) - } + g.readerState[name] = true } func (g *TermManager) CountWritten() uint64 { @@ -268,7 +270,11 @@ func (g *TermManager) CountRead() uint64 { } func (g *TermManager) Close() { - if atomic.CompareAndSwapInt32(g.closed, 0, 1) { + g.mu.Lock() + defer g.mu.Unlock() + + if !g.closed { + g.closed = true close(g.terminateNotifier) } } diff --git a/lib/srv/termmanager_test.go b/lib/srv/termmanager_test.go new file mode 100644 index 0000000000000..feb5667bab47c --- /dev/null +++ b/lib/srv/termmanager_test.go @@ -0,0 +1,50 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package srv + +import ( + "io" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCTRLCPassthrough(t *testing.T) { + m := NewTermManager() + m.On() + r, w := io.Pipe() + m.AddReader("foo", r) + go w.Write([]byte("\x03")) + buf := make([]byte, 1) + _, err := m.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("\x03"), buf) +} + +func TestCTRLCCapture(t *testing.T) { + m := NewTermManager() + r, w := io.Pipe() + m.AddReader("foo", r) + go w.Write([]byte("\x03")) + + select { + case <-m.TerminateNotifier(): + case <-time.After(time.Second * 10): + t.Fatal("terminateNotifier should've seen an event") + } +}