Skip to content

Commit fb009b1

Browse files
authored
chore(consumer): handle all job before shutdown. (#68)
1 parent 965207c commit fb009b1

File tree

2 files changed

+77
-5
lines changed

2 files changed

+77
-5
lines changed

consumer.go

+9
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ type Consumer struct {
2020
taskQueue chan core.QueuedMessage
2121
runFunc func(context.Context, core.QueuedMessage) error
2222
stop chan struct{}
23+
exit chan struct{}
2324
logger Logger
2425
stopOnce sync.Once
2526
stopFlag int32
@@ -101,6 +102,9 @@ func (s *Consumer) Shutdown() error {
101102
s.stopOnce.Do(func() {
102103
close(s.stop)
103104
close(s.taskQueue)
105+
if len(s.taskQueue) > 0 {
106+
<-s.exit
107+
}
104108
})
105109
return nil
106110
}
@@ -127,6 +131,10 @@ loop:
127131
select {
128132
case task, ok := <-s.taskQueue:
129133
if !ok {
134+
select {
135+
case s.exit <- struct{}{}:
136+
default:
137+
}
130138
return nil, ErrQueueHasBeenClosed
131139
}
132140
return task, nil
@@ -147,6 +155,7 @@ func NewConsumer(opts ...Option) *Consumer {
147155
w := &Consumer{
148156
taskQueue: make(chan core.QueuedMessage, o.queueSize),
149157
stop: make(chan struct{}),
158+
exit: make(chan struct{}),
150159
logger: o.logger,
151160
runFunc: o.fn,
152161
}

consumer_test.go

+68-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"time"
1111

1212
"github.com/golang-queue/queue/core"
13+
"github.com/golang-queue/queue/mocks"
14+
"github.com/golang/mock/gomock"
1315

1416
"github.com/stretchr/testify/assert"
1517
)
@@ -237,8 +239,6 @@ func TestHandleTimeout(t *testing.T) {
237239
done <- w.handle(job)
238240
}()
239241

240-
assert.NoError(t, w.Shutdown())
241-
242242
err = <-done
243243
assert.Error(t, err)
244244
assert.Equal(t, context.DeadlineExceeded, err)
@@ -276,8 +276,6 @@ func TestJobComplete(t *testing.T) {
276276
done <- w.handle(job)
277277
}()
278278

279-
assert.NoError(t, w.Shutdown())
280-
281279
err = <-done
282280
assert.Error(t, err)
283281
assert.Equal(t, errors.New("job completed"), err)
@@ -308,7 +306,7 @@ func TestTaskJobComplete(t *testing.T) {
308306
go func() {
309307
done <- w.handle(job)
310308
}()
311-
assert.NoError(t, w.Shutdown())
309+
312310
err = <-done
313311
assert.NoError(t, err)
314312

@@ -385,3 +383,68 @@ func TestDecreaseWorkerCount(t *testing.T) {
385383
assert.Equal(t, 2, q.BusyWorkers())
386384
q.Release()
387385
}
386+
387+
func TestHandleAllJobBeforeShutdownConsumer(t *testing.T) {
388+
controller := gomock.NewController(t)
389+
defer controller.Finish()
390+
391+
m := mocks.NewMockQueuedMessage(controller)
392+
393+
w := NewConsumer(
394+
WithFn(func(ctx context.Context, m core.QueuedMessage) error {
395+
time.Sleep(10 * time.Millisecond)
396+
return nil
397+
}),
398+
)
399+
400+
done := make(chan struct{})
401+
assert.NoError(t, w.Queue(m))
402+
assert.NoError(t, w.Queue(m))
403+
go func() {
404+
assert.NoError(t, w.Shutdown())
405+
done <- struct{}{}
406+
}()
407+
408+
task, err := w.Request()
409+
assert.NotNil(t, task)
410+
assert.NoError(t, err)
411+
task, err = w.Request()
412+
assert.NotNil(t, task)
413+
assert.NoError(t, err)
414+
task, err = w.Request()
415+
assert.Nil(t, task)
416+
assert.True(t, errors.Is(err, ErrQueueHasBeenClosed))
417+
<-done
418+
}
419+
420+
func TestHandleAllJobBeforeShutdownConsumerInQueue(t *testing.T) {
421+
controller := gomock.NewController(t)
422+
defer controller.Finish()
423+
424+
m := mocks.NewMockQueuedMessage(controller)
425+
m.EXPECT().Bytes().Return([]byte("test")).AnyTimes()
426+
427+
messages := make(chan string, 10)
428+
429+
w := NewConsumer(
430+
WithFn(func(ctx context.Context, m core.QueuedMessage) error {
431+
time.Sleep(10 * time.Millisecond)
432+
messages <- string(m.Bytes())
433+
return nil
434+
}),
435+
)
436+
437+
q, err := NewQueue(
438+
WithLogger(NewLogger()),
439+
WithWorker(w),
440+
WithWorkerCount(1),
441+
)
442+
assert.NoError(t, err)
443+
444+
assert.NoError(t, q.Queue(m))
445+
assert.NoError(t, q.Queue(m))
446+
assert.Len(t, messages, 0)
447+
q.Start()
448+
q.Release()
449+
assert.Len(t, messages, 2)
450+
}

0 commit comments

Comments
 (0)