Skip to content

Commit

Permalink
fix various race conditions when writing packets to closed clients or…
Browse files Browse the repository at this point in the history
… server sessions (#684)
  • Loading branch information
aler9 authored Jan 19, 2025
1 parent b2cfa93 commit ca62863
Show file tree
Hide file tree
Showing 12 changed files with 433 additions and 214 deletions.
21 changes: 10 additions & 11 deletions async_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,29 @@ type asyncProcessor struct {
buffer *ringbuffer.RingBuffer
stopError error

stopped chan struct{}
chStopped chan struct{}
}

func (w *asyncProcessor) initialize() {
w.buffer, _ = ringbuffer.New(uint64(w.bufferSize))
}

func (w *asyncProcessor) start() {
w.running = true
w.stopped = make(chan struct{})
go w.run()
}

func (w *asyncProcessor) stop() {
func (w *asyncProcessor) close() {
if w.running {
w.buffer.Close()
<-w.stopped
w.running = false
<-w.chStopped
}
}

func (w *asyncProcessor) start() {
w.running = true
w.chStopped = make(chan struct{})
go w.run()
}

func (w *asyncProcessor) run() {
w.stopError = w.runInner()
close(w.stopped)
close(w.chStopped)
}

func (w *asyncProcessor) runInner() error {
Expand Down
6 changes: 3 additions & 3 deletions async_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestAsyncProcessorStopAfterError(t *testing.T) {
func TestAsyncProcessorCloseAfterError(t *testing.T) {
p := &asyncProcessor{bufferSize: 8}
p.initialize()

Expand All @@ -17,8 +17,8 @@ func TestAsyncProcessorStopAfterError(t *testing.T) {

p.start()

<-p.stopped
<-p.chStopped
require.EqualError(t, p.stopError, "ok")

p.stop()
p.close()
}
94 changes: 68 additions & 26 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -340,6 +341,7 @@ type Client struct {
keepaliveTimer *time.Timer
closeError error
writer *asyncProcessor
writerMutex sync.RWMutex
reader *clientReader
timeDecoder *rtptime.GlobalDecoder2
mustClose bool
Expand Down Expand Up @@ -560,8 +562,8 @@ func (c *Client) runInner() error {
}()

chWriterError := func() chan struct{} {
if c.writer != nil && c.writer.running {
return c.writer.stopped
if c.writer != nil {
return c.writer.chStopped
}
return nil
}()
Expand Down Expand Up @@ -721,7 +723,7 @@ func (c *Client) handleServerRequest(req *base.Request) error {

func (c *Client) doClose() {
if c.state == clientStatePlay || c.state == clientStateRecord {
c.writer.stop()
c.destroyWriter()
c.stopTransportRoutines()
}

Expand Down Expand Up @@ -848,22 +850,6 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
}

func (c *Client) startTransportRoutines() {
// allocate writer here because it's needed by RTCP receiver / sender
if c.state == clientStateRecord || c.backChannelSetupped {
c.writer = &asyncProcessor{
bufferSize: c.WriteQueueSize,
}
c.writer.initialize()
} else {
// when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
c.writer = &asyncProcessor{
bufferSize: 8,
}
c.writer.initialize()
}

c.timeDecoder = rtptime.NewGlobalDecoder2()

for _, cm := range c.setuppedMedias {
Expand Down Expand Up @@ -913,6 +899,39 @@ func (c *Client) stopTransportRoutines() {
c.timeDecoder = nil
}

func (c *Client) createWriter() {
c.writerMutex.Lock()

c.writer = &asyncProcessor{
bufferSize: func() int {
if c.state == clientStateRecord || c.backChannelSetupped {
return c.WriteQueueSize
}

// when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
return 8
}(),
}

c.writer.initialize()

c.writerMutex.Unlock()
}

func (c *Client) startWriter() {
c.writer.start()
}

func (c *Client) destroyWriter() {
c.writer.close()

c.writerMutex.Lock()
c.writer = nil
c.writerMutex.Unlock()
}

func (c *Client) connOpen() error {
if c.nconn != nil {
return nil
Expand Down Expand Up @@ -1389,7 +1408,7 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
}

err = cm.allocateUDPListeners(
err = cm.createUDPListeners(
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
Expand Down Expand Up @@ -1544,7 +1563,7 @@ func (c *Client) doSetup(
readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
}

err = cm.allocateUDPListeners(
err = cm.createUDPListeners(
true,
readIP,
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)),
Expand Down Expand Up @@ -1680,6 +1699,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {

c.state = clientStatePlay
c.startTransportRoutines()
c.createWriter()

// Range is mandatory in Parrot Streaming Server
if ra == nil {
Expand All @@ -1704,12 +1724,14 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
Header: header,
}, false)
if err != nil {
c.destroyWriter()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, err
}

if res.StatusCode != base.StatusOK {
c.destroyWriter()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, liberrors.ErrClientBadStatusCode{
Expand All @@ -1731,7 +1753,8 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
}
}

c.writer.start()
c.startWriter()

c.lastRange = ra

return res, nil
Expand Down Expand Up @@ -1761,26 +1784,29 @@ func (c *Client) doRecord() (*base.Response, error) {

c.state = clientStateRecord
c.startTransportRoutines()
c.createWriter()

res, err := c.do(&base.Request{
Method: base.Record,
URL: c.baseURL,
}, false)
if err != nil {
c.destroyWriter()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, err
}

if res.StatusCode != base.StatusOK {
c.destroyWriter()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}

c.writer.start()
c.startWriter()

return nil, nil
}
Expand Down Expand Up @@ -1808,19 +1834,21 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err
}

c.writer.stop()
c.destroyWriter()

res, err := c.do(&base.Request{
Method: base.Pause,
URL: c.baseURL,
}, false)
if err != nil {
c.writer.start()
c.createWriter()
c.startWriter()
return nil, err
}

if res.StatusCode != base.StatusOK {
c.writer.start()
c.createWriter()
c.startWriter()
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
Expand Down Expand Up @@ -1918,6 +1946,13 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet,
default:
}

c.writerMutex.RLock()
defer c.writerMutex.RUnlock()

if c.writer == nil {
return nil
}

cm := c.setuppedMedias[medi]
cf := cm.formats[pkt.PayloadType]

Expand Down Expand Up @@ -1946,6 +1981,13 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
default:
}

c.writerMutex.RLock()
defer c.writerMutex.RUnlock()

if c.writer == nil {
return nil
}

cm := c.setuppedMedias[medi]

ok := c.writer.push(func() error {
Expand Down
4 changes: 2 additions & 2 deletions client_media.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (cm *clientMedia) close() {
}
}

func (cm *clientMedia) allocateUDPListeners(
func (cm *clientMedia) createUDPListeners(
multicastEnable bool,
multicastSourceIP net.IP,
rtpAddress string,
Expand Down Expand Up @@ -94,7 +94,7 @@ func (cm *clientMedia) allocateUDPListeners(
}

var err error
cm.udpRTPListener, cm.udpRTCPListener, err = allocateUDPListenerPair(cm.c)
cm.udpRTPListener, cm.udpRTCPListener, err = createUDPListenerPair(cm.c)
return err
}

Expand Down
2 changes: 1 addition & 1 deletion client_play_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1813,7 +1813,7 @@ func TestClientPlayRedirect(t *testing.T) {
}
}

func TestClientPlayPause(t *testing.T) {
func TestClientPlayPausePlay(t *testing.T) {
writeFrames := func(inTH *headers.Transport, conn *conn.Conn) (chan struct{}, chan struct{}) {
writerTerminate := make(chan struct{})
writerDone := make(chan struct{})
Expand Down
Loading

0 comments on commit ca62863

Please # to comment.