Skip to content

Commit 54d4299

Browse files
committed
context: add AfterFunc
Add an AfterFunc function, which registers a function to run after a context has been canceled. Add support for contexts that implement an AfterFunc method, which can be used to avoid the need to start a new goroutine watching the Done channel when propagating cancellation signals. Fixes #57928 Change-Id: If0b2cdcc4332961276a1ff57311338e74916259c Reviewed-on: https://go-review.googlesource.com/c/go/+/482695 TryBot-Result: Gopher Robot <gobot@golang.org> Run-TryBot: Damien Neil <dneil@google.com> Reviewed-by: Sameer Ajmani <sameer@golang.org>
1 parent 9d53d7a commit 54d4299

File tree

6 files changed

+533
-44
lines changed

6 files changed

+533
-44
lines changed

api/next/57928.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pkg context, func AfterFunc(Context, func()) func() bool #57928

src/context/afterfunc_test.go

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package context_test
6+
7+
import (
8+
"context"
9+
"sync"
10+
"testing"
11+
"time"
12+
)
13+
14+
// afterFuncContext is a context that's not one of the types
15+
// defined in context.go, that supports registering AfterFuncs.
16+
type afterFuncContext struct {
17+
mu sync.Mutex
18+
afterFuncs map[*struct{}]func()
19+
done chan struct{}
20+
err error
21+
}
22+
23+
func newAfterFuncContext() context.Context {
24+
return &afterFuncContext{}
25+
}
26+
27+
func (c *afterFuncContext) Deadline() (time.Time, bool) {
28+
return time.Time{}, false
29+
}
30+
31+
func (c *afterFuncContext) Done() <-chan struct{} {
32+
c.mu.Lock()
33+
defer c.mu.Unlock()
34+
if c.done == nil {
35+
c.done = make(chan struct{})
36+
}
37+
return c.done
38+
}
39+
40+
func (c *afterFuncContext) Err() error {
41+
c.mu.Lock()
42+
defer c.mu.Unlock()
43+
return c.err
44+
}
45+
46+
func (c *afterFuncContext) Value(key any) any {
47+
return nil
48+
}
49+
50+
func (c *afterFuncContext) AfterFunc(f func()) func() bool {
51+
c.mu.Lock()
52+
defer c.mu.Unlock()
53+
k := &struct{}{}
54+
if c.afterFuncs == nil {
55+
c.afterFuncs = make(map[*struct{}]func())
56+
}
57+
c.afterFuncs[k] = f
58+
return func() bool {
59+
c.mu.Lock()
60+
defer c.mu.Unlock()
61+
_, ok := c.afterFuncs[k]
62+
delete(c.afterFuncs, k)
63+
return ok
64+
}
65+
}
66+
67+
func (c *afterFuncContext) cancel(err error) {
68+
c.mu.Lock()
69+
defer c.mu.Unlock()
70+
if c.err != nil {
71+
return
72+
}
73+
c.err = err
74+
for _, f := range c.afterFuncs {
75+
go f()
76+
}
77+
c.afterFuncs = nil
78+
}
79+
80+
func TestCustomContextAfterFuncCancel(t *testing.T) {
81+
ctx0 := &afterFuncContext{}
82+
ctx1, cancel := context.WithCancel(ctx0)
83+
defer cancel()
84+
ctx0.cancel(context.Canceled)
85+
<-ctx1.Done()
86+
}
87+
88+
func TestCustomContextAfterFuncTimeout(t *testing.T) {
89+
ctx0 := &afterFuncContext{}
90+
ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
91+
defer cancel()
92+
ctx0.cancel(context.Canceled)
93+
<-ctx1.Done()
94+
}
95+
96+
func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
97+
ctx0 := &afterFuncContext{}
98+
donec := make(chan struct{})
99+
stop := context.AfterFunc(ctx0, func() {
100+
close(donec)
101+
})
102+
defer stop()
103+
ctx0.cancel(context.Canceled)
104+
<-donec
105+
}
106+
107+
func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
108+
ctx0 := &afterFuncContext{}
109+
_, cancel := context.WithCancel(ctx0)
110+
if got, want := len(ctx0.afterFuncs), 1; got != want {
111+
t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
112+
}
113+
cancel()
114+
if got, want := len(ctx0.afterFuncs), 0; got != want {
115+
t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
116+
}
117+
}
118+
119+
func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
120+
ctx0 := &afterFuncContext{}
121+
_, cancel := context.WithTimeout(ctx0, veryLongDuration)
122+
if got, want := len(ctx0.afterFuncs), 1; got != want {
123+
t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
124+
}
125+
cancel()
126+
if got, want := len(ctx0.afterFuncs), 0; got != want {
127+
t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
128+
}
129+
}
130+
131+
func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
132+
ctx0 := &afterFuncContext{}
133+
stop := context.AfterFunc(ctx0, func() {})
134+
if got, want := len(ctx0.afterFuncs), 1; got != want {
135+
t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
136+
}
137+
stop()
138+
if got, want := len(ctx0.afterFuncs), 0; got != want {
139+
t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
140+
}
141+
}

src/context/context.go

+126-41
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,8 @@ func withCancel(parent Context) *cancelCtx {
269269
if parent == nil {
270270
panic("cannot create context from nil parent")
271271
}
272-
c := &cancelCtx{Context: parent}
273-
propagateCancel(parent, c)
272+
c := &cancelCtx{}
273+
c.propagateCancel(parent, c)
274274
return c
275275
}
276276

@@ -289,48 +289,72 @@ func Cause(c Context) error {
289289
return nil
290290
}
291291

292-
// goroutines counts the number of goroutines ever created; for testing.
293-
var goroutines atomic.Int32
294-
295-
// propagateCancel arranges for child to be canceled when parent is.
296-
func propagateCancel(parent Context, child canceler) {
297-
done := parent.Done()
298-
if done == nil {
299-
return // parent is never canceled
292+
// AfterFunc arranges to call f in its own goroutine after ctx is done
293+
// (cancelled or timed out).
294+
// If ctx is already done, AfterFunc calls f immediately in its own goroutine.
295+
//
296+
// Multiple calls to AfterFunc on a context operate independently;
297+
// one does not replace another.
298+
//
299+
// Calling the returned stop function stops the association of ctx with f.
300+
// It returns true if the call stopped f from being run.
301+
// If stop returns false,
302+
// either the context is done and f has been started in its own goroutine;
303+
// or f was already stopped.
304+
// The stop function does not wait for f to complete before returning.
305+
// If the caller needs to know whether f is completed,
306+
// it must coordinate with f explicitly.
307+
//
308+
// If ctx has a "AfterFunc(func()) func() bool" method,
309+
// AfterFunc will use it to schedule the call.
310+
func AfterFunc(ctx Context, f func()) (stop func() bool) {
311+
a := &afterFuncCtx{
312+
f: f,
313+
}
314+
a.cancelCtx.propagateCancel(ctx, a)
315+
return func() bool {
316+
stopped := false
317+
a.once.Do(func() {
318+
stopped = true
319+
})
320+
if stopped {
321+
a.cancel(true, Canceled, nil)
322+
}
323+
return stopped
300324
}
325+
}
301326

302-
select {
303-
case <-done:
304-
// parent is already canceled
305-
child.cancel(false, parent.Err(), Cause(parent))
306-
return
307-
default:
308-
}
327+
type afterFuncer interface {
328+
AfterFunc(func()) func() bool
329+
}
309330

310-
if p, ok := parentCancelCtx(parent); ok {
311-
p.mu.Lock()
312-
if p.err != nil {
313-
// parent has already been canceled
314-
child.cancel(false, p.err, p.cause)
315-
} else {
316-
if p.children == nil {
317-
p.children = make(map[canceler]struct{})
318-
}
319-
p.children[child] = struct{}{}
320-
}
321-
p.mu.Unlock()
322-
} else {
323-
goroutines.Add(1)
324-
go func() {
325-
select {
326-
case <-parent.Done():
327-
child.cancel(false, parent.Err(), Cause(parent))
328-
case <-child.Done():
329-
}
330-
}()
331+
type afterFuncCtx struct {
332+
cancelCtx
333+
once sync.Once // either starts running f or stops f from running
334+
f func()
335+
}
336+
337+
func (a *afterFuncCtx) cancel(removeFromParent bool, err, cause error) {
338+
a.cancelCtx.cancel(false, err, cause)
339+
if removeFromParent {
340+
removeChild(a.Context, a)
331341
}
342+
a.once.Do(func() {
343+
go a.f()
344+
})
332345
}
333346

347+
// A stopCtx is used as the parent context of a cancelCtx when
348+
// an AfterFunc has been registered with the parent.
349+
// It holds the stop function used to unregister the AfterFunc.
350+
type stopCtx struct {
351+
Context
352+
stop func() bool
353+
}
354+
355+
// goroutines counts the number of goroutines ever created; for testing.
356+
var goroutines atomic.Int32
357+
334358
// &cancelCtxKey is the key that a cancelCtx returns itself for.
335359
var cancelCtxKey int
336360

@@ -358,6 +382,10 @@ func parentCancelCtx(parent Context) (*cancelCtx, bool) {
358382

359383
// removeChild removes a context from its parent.
360384
func removeChild(parent Context, child canceler) {
385+
if s, ok := parent.(stopCtx); ok {
386+
s.stop()
387+
return
388+
}
361389
p, ok := parentCancelCtx(parent)
362390
if !ok {
363391
return
@@ -424,6 +452,64 @@ func (c *cancelCtx) Err() error {
424452
return err
425453
}
426454

455+
// propagateCancel arranges for child to be canceled when parent is.
456+
// It sets the parent context of cancelCtx.
457+
func (c *cancelCtx) propagateCancel(parent Context, child canceler) {
458+
c.Context = parent
459+
460+
done := parent.Done()
461+
if done == nil {
462+
return // parent is never canceled
463+
}
464+
465+
select {
466+
case <-done:
467+
// parent is already canceled
468+
child.cancel(false, parent.Err(), Cause(parent))
469+
return
470+
default:
471+
}
472+
473+
if p, ok := parentCancelCtx(parent); ok {
474+
// parent is a *cancelCtx, or derives from one.
475+
p.mu.Lock()
476+
if p.err != nil {
477+
// parent has already been canceled
478+
child.cancel(false, p.err, p.cause)
479+
} else {
480+
if p.children == nil {
481+
p.children = make(map[canceler]struct{})
482+
}
483+
p.children[child] = struct{}{}
484+
}
485+
p.mu.Unlock()
486+
return
487+
}
488+
489+
if a, ok := parent.(afterFuncer); ok {
490+
// parent implements an AfterFunc method.
491+
c.mu.Lock()
492+
stop := a.AfterFunc(func() {
493+
child.cancel(false, parent.Err(), Cause(parent))
494+
})
495+
c.Context = stopCtx{
496+
Context: parent,
497+
stop: stop,
498+
}
499+
c.mu.Unlock()
500+
return
501+
}
502+
503+
goroutines.Add(1)
504+
go func() {
505+
select {
506+
case <-parent.Done():
507+
child.cancel(false, parent.Err(), Cause(parent))
508+
case <-child.Done():
509+
}
510+
}()
511+
}
512+
427513
type stringer interface {
428514
String() string
429515
}
@@ -533,10 +619,9 @@ func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, Cance
533619
return WithCancel(parent)
534620
}
535621
c := &timerCtx{
536-
cancelCtx: cancelCtx{Context: parent},
537-
deadline: d,
622+
deadline: d,
538623
}
539-
propagateCancel(parent, c)
624+
c.cancelCtx.propagateCancel(parent, c)
540625
dur := time.Until(d)
541626
if dur <= 0 {
542627
c.cancel(true, DeadlineExceeded, cause) // deadline has already passed

0 commit comments

Comments
 (0)