diff --git a/fake/fake_clock.go b/fake/fake_clock.go index e56345f..bbf5764 100644 --- a/fake/fake_clock.go +++ b/fake/fake_clock.go @@ -23,6 +23,10 @@ type Clock struct { // function to the wakeup time. (protected by mu) cbs map[*stopTimer]time.Time + // cbsWG tracks callback goroutines configured from AfterFunc (no mutex + // protection necessary) + cbsWG sync.WaitGroup + // cond is broadcasted() upon any sleep or wakeup event (mutations to // sleepers or cbs). cond sync.Cond @@ -63,7 +67,7 @@ func NewClock(initialTime time.Time) *Clock { } // returns the number of sleepers awoken -func (f *Clock) setClockLocked(t time.Time) int { +func (f *Clock) setClockLocked(t time.Time, cbRunningWG *sync.WaitGroup) int { awoken := 0 for ch, target := range f.sleepers { if target.Sub(t) <= 0 { @@ -75,7 +79,13 @@ func (f *Clock) setClockLocked(t time.Time) int { cbsRun := 0 for s, target := range f.cbs { if target.Sub(t) <= 0 { - go s.f() + cbRunningWG.Add(1) + f.cbsWG.Add(1) + go func() { + defer f.cbsWG.Done() + cbRunningWG.Done() + s.f() + }() delete(f.cbs, s) cbsRun++ } @@ -88,19 +98,31 @@ func (f *Clock) setClockLocked(t time.Time) int { } // SetClock skips the FakeClock to the specified time (forward or backwards) +// The goroutines running newly-spawned functions scheduled with AfterFunc are +// guaranteed to have scheduled by the time this function returns. func (f *Clock) SetClock(t time.Time) int { + cbsWG := sync.WaitGroup{} + // Wait for callbacks to schedule before returning (but after the mutex + // is unlocked) + defer cbsWG.Wait() f.mu.Lock() defer f.mu.Unlock() - return f.setClockLocked(t) + return f.setClockLocked(t, &cbsWG) } // Advance skips the FakeClock forward by the specified duration (backwards if // negative) +// The goroutines running newly-spawned functions scheduled with AfterFunc are +// guaranteed to have scheduled by the time this function returns. func (f *Clock) Advance(dur time.Duration) int { + cbsWG := sync.WaitGroup{} + // Wait for callbacks to schedule before returning (but after the mutex + // is unlocked) + defer cbsWG.Wait() f.mu.Lock() defer f.mu.Unlock() t := f.current.Add(dur) - return f.setClockLocked(t) + return f.setClockLocked(t, &cbsWG) } // NumSleepers returns the number of goroutines waiting in SleepFor and SleepUntil @@ -307,7 +329,11 @@ func (f *Clock) AfterFunc(d time.Duration, cb func()) clocks.StopTimer { // run by the time the function has returned). if d <= 0 { f.callbackExecs++ - go cb() + f.cbsWG.Add(1) + go func() { + defer f.cbsWG.Done() + cb() + }() return doaStopTimer{} } wakeTime := f.current.Add(d) @@ -376,3 +402,9 @@ func (f *Clock) AwaitTimerAborts(n int) { f.cond.Wait() } } + +// WaitAfterFuncs blocks until all currently running AfterFunc callbacks +// return. +func (f *Clock) WaitAfterFuncs() { + f.cbsWG.Wait() +} diff --git a/fake/fake_clock_test.go b/fake/fake_clock_test.go index 0683a9f..16beb29 100644 --- a/fake/fake_clock_test.go +++ b/fake/fake_clock_test.go @@ -432,6 +432,7 @@ func TestFakeClockAfterFuncTimeWake(t *testing.T) { cbRun := make(chan struct{}) timerHandle := fc.AfterFunc(time.Hour, func() { close(cbRun) }) + fc.WaitAfterFuncs() <-aggCallbackWaitCh <-regCallbackWaitCh @@ -536,8 +537,10 @@ func TestFakeClockAfterFuncTimeAbort(t *testing.T) { cbRun := make(chan struct{}) timerHandle := fc.AfterFunc(time.Hour, func() { close(cbRun) }) + fc.WaitAfterFuncs() <-aggCallbackWaitCh <-regCallbackWaitCh + fc.WaitAfterFuncs() if regCBs := fc.NumRegisteredCallbacks(); regCBs != 1 { t.Errorf("unexpected registered callbacks: %d; expected 1", regCBs) @@ -617,6 +620,7 @@ func TestFakeClockAfterFuncNegDur(t *testing.T) { cbRun := make(chan struct{}) timerHandle := fc.AfterFunc(-time.Hour, func() { close(cbRun) }) + fc.WaitAfterFuncs() <-aggCallbackWaitCh <-cbRun