From 5c5141be35da354fa432ad388f4a7c3b5074a07a Mon Sep 17 00:00:00 2001 From: David Finkel Date: Wed, 16 Sep 2020 10:24:38 -0400 Subject: [PATCH] fake: track spawned goroutines Since AfterFunc configured callbacks are run in newly spawned goroutines, it's useful to at least have the guarantee that they've started when Advance and SetClock return. Also add a WaitGroup to track when those goroutines complete. --- fake/fake_clock.go | 42 ++++++++++++++++++++++++++++++++++++----- fake/fake_clock_test.go | 4 ++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/fake/fake_clock.go b/fake/fake_clock.go index e56345f..bbf5764 100644 --- a/fake/fake_clock.go +++ b/fake/fake_clock.go @@ -23,6 +23,10 @@ type Clock struct { // function to the wakeup time. (protected by mu) cbs map[*stopTimer]time.Time + // cbsWG tracks callback goroutines configured from AfterFunc (no mutex + // protection necessary) + cbsWG sync.WaitGroup + // cond is broadcasted() upon any sleep or wakeup event (mutations to // sleepers or cbs). cond sync.Cond @@ -63,7 +67,7 @@ func NewClock(initialTime time.Time) *Clock { } // returns the number of sleepers awoken -func (f *Clock) setClockLocked(t time.Time) int { +func (f *Clock) setClockLocked(t time.Time, cbRunningWG *sync.WaitGroup) int { awoken := 0 for ch, target := range f.sleepers { if target.Sub(t) <= 0 { @@ -75,7 +79,13 @@ func (f *Clock) setClockLocked(t time.Time) int { cbsRun := 0 for s, target := range f.cbs { if target.Sub(t) <= 0 { - go s.f() + cbRunningWG.Add(1) + f.cbsWG.Add(1) + go func() { + defer f.cbsWG.Done() + cbRunningWG.Done() + s.f() + }() delete(f.cbs, s) cbsRun++ } @@ -88,19 +98,31 @@ func (f *Clock) setClockLocked(t time.Time) int { } // SetClock skips the FakeClock to the specified time (forward or backwards) +// The goroutines running newly-spawned functions scheduled with AfterFunc are +// guaranteed to have scheduled by the time this function returns. func (f *Clock) SetClock(t time.Time) int { + cbsWG := sync.WaitGroup{} + // Wait for callbacks to schedule before returning (but after the mutex + // is unlocked) + defer cbsWG.Wait() f.mu.Lock() defer f.mu.Unlock() - return f.setClockLocked(t) + return f.setClockLocked(t, &cbsWG) } // Advance skips the FakeClock forward by the specified duration (backwards if // negative) +// The goroutines running newly-spawned functions scheduled with AfterFunc are +// guaranteed to have scheduled by the time this function returns. func (f *Clock) Advance(dur time.Duration) int { + cbsWG := sync.WaitGroup{} + // Wait for callbacks to schedule before returning (but after the mutex + // is unlocked) + defer cbsWG.Wait() f.mu.Lock() defer f.mu.Unlock() t := f.current.Add(dur) - return f.setClockLocked(t) + return f.setClockLocked(t, &cbsWG) } // NumSleepers returns the number of goroutines waiting in SleepFor and SleepUntil @@ -307,7 +329,11 @@ func (f *Clock) AfterFunc(d time.Duration, cb func()) clocks.StopTimer { // run by the time the function has returned). if d <= 0 { f.callbackExecs++ - go cb() + f.cbsWG.Add(1) + go func() { + defer f.cbsWG.Done() + cb() + }() return doaStopTimer{} } wakeTime := f.current.Add(d) @@ -376,3 +402,9 @@ func (f *Clock) AwaitTimerAborts(n int) { f.cond.Wait() } } + +// WaitAfterFuncs blocks until all currently running AfterFunc callbacks +// return. +func (f *Clock) WaitAfterFuncs() { + f.cbsWG.Wait() +} diff --git a/fake/fake_clock_test.go b/fake/fake_clock_test.go index 0683a9f..16beb29 100644 --- a/fake/fake_clock_test.go +++ b/fake/fake_clock_test.go @@ -432,6 +432,7 @@ func TestFakeClockAfterFuncTimeWake(t *testing.T) { cbRun := make(chan struct{}) timerHandle := fc.AfterFunc(time.Hour, func() { close(cbRun) }) + fc.WaitAfterFuncs() <-aggCallbackWaitCh <-regCallbackWaitCh @@ -536,8 +537,10 @@ func TestFakeClockAfterFuncTimeAbort(t *testing.T) { cbRun := make(chan struct{}) timerHandle := fc.AfterFunc(time.Hour, func() { close(cbRun) }) + fc.WaitAfterFuncs() <-aggCallbackWaitCh <-regCallbackWaitCh + fc.WaitAfterFuncs() if regCBs := fc.NumRegisteredCallbacks(); regCBs != 1 { t.Errorf("unexpected registered callbacks: %d; expected 1", regCBs) @@ -617,6 +620,7 @@ func TestFakeClockAfterFuncNegDur(t *testing.T) { cbRun := make(chan struct{}) timerHandle := fc.AfterFunc(-time.Hour, func() { close(cbRun) }) + fc.WaitAfterFuncs() <-aggCallbackWaitCh <-cbRun