Skip to content

Commit d851d66

Browse files
authored
fix(consumer): stop retry task (#88)
when shutdown the service or task timeout
1 parent a8fae77 commit d851d66

File tree

2 files changed

+92
-3
lines changed

2 files changed

+92
-3
lines changed

consumer.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ func (s *Consumer) handle(m *job.Message) error {
4848

4949
// run custom process function
5050
var err error
51-
shouldRetry := true
52-
for shouldRetry {
51+
loop:
52+
for {
5353
if m.Task != nil {
5454
err = m.Task(ctx)
5555
} else {
@@ -62,7 +62,12 @@ func (s *Consumer) handle(m *job.Message) error {
6262
}
6363
m.RetryCount--
6464

65-
<-time.After(m.RetryDelay)
65+
select {
66+
case <-time.After(m.RetryDelay): // retry delay time
67+
case <-ctx.Done(): // timeout reached
68+
err = ctx.Err()
69+
break loop
70+
}
6671
}
6772

6873
done <- err

consumer_test.go

+84
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ func TestRetryCountWithNewMessage(t *testing.T) {
458458
m.EXPECT().Bytes().Return([]byte("test")).AnyTimes()
459459

460460
messages := make(chan string, 10)
461+
keep := make(chan struct{})
461462
count := 1
462463

463464
w := NewConsumer(
@@ -466,6 +467,7 @@ func TestRetryCountWithNewMessage(t *testing.T) {
466467
count++
467468
return errors.New("count not correct")
468469
}
470+
close(keep)
469471
messages <- string(m.Bytes())
470472
return nil
471473
}),
@@ -485,6 +487,8 @@ func TestRetryCountWithNewMessage(t *testing.T) {
485487
))
486488
assert.Len(t, messages, 0)
487489
q.Start()
490+
// wait retry twice.
491+
<-keep
488492
q.Release()
489493
assert.Len(t, messages, 1)
490494
}
@@ -502,12 +506,15 @@ func TestRetryCountWithNewTask(t *testing.T) {
502506
)
503507
assert.NoError(t, err)
504508

509+
keep := make(chan struct{})
510+
505511
assert.NoError(t, q.QueueTask(
506512
func(ctx context.Context) error {
507513
if count%3 != 0 {
508514
count++
509515
return errors.New("count not correct")
510516
}
517+
close(keep)
511518
messages <- "foobar"
512519
return nil
513520
},
@@ -516,6 +523,83 @@ func TestRetryCountWithNewTask(t *testing.T) {
516523
))
517524
assert.Len(t, messages, 0)
518525
q.Start()
526+
// wait retry twice.
527+
<-keep
519528
q.Release()
520529
assert.Len(t, messages, 1)
521530
}
531+
532+
func TestCancelRetryCountWithNewTask(t *testing.T) {
533+
messages := make(chan string, 10)
534+
count := 1
535+
536+
w := NewConsumer()
537+
538+
q, err := NewQueue(
539+
WithLogger(NewLogger()),
540+
WithWorker(w),
541+
WithWorkerCount(1),
542+
)
543+
assert.NoError(t, err)
544+
545+
assert.NoError(t, q.QueueTask(
546+
func(ctx context.Context) error {
547+
if count%3 != 0 {
548+
count++
549+
q.logger.Info("add count")
550+
return errors.New("count not correct")
551+
}
552+
messages <- "foobar"
553+
return nil
554+
},
555+
job.WithRetryCount(3),
556+
job.WithRetryDelay(100*time.Millisecond),
557+
))
558+
assert.Len(t, messages, 0)
559+
q.Start()
560+
time.Sleep(50 * time.Millisecond)
561+
q.Release()
562+
assert.Len(t, messages, 0)
563+
assert.Equal(t, 2, count)
564+
}
565+
566+
func TestCancelRetryCountWithNewMessage(t *testing.T) {
567+
controller := gomock.NewController(t)
568+
defer controller.Finish()
569+
570+
m := mocks.NewMockQueuedMessage(controller)
571+
m.EXPECT().Bytes().Return([]byte("test")).AnyTimes()
572+
573+
messages := make(chan string, 10)
574+
count := 1
575+
576+
w := NewConsumer(
577+
WithFn(func(ctx context.Context, m core.QueuedMessage) error {
578+
if count%3 != 0 {
579+
count++
580+
return errors.New("count not correct")
581+
}
582+
messages <- string(m.Bytes())
583+
return nil
584+
}),
585+
)
586+
587+
q, err := NewQueue(
588+
WithLogger(NewLogger()),
589+
WithWorker(w),
590+
WithWorkerCount(1),
591+
)
592+
assert.NoError(t, err)
593+
594+
assert.NoError(t, q.Queue(
595+
m,
596+
job.WithRetryCount(3),
597+
job.WithRetryDelay(100*time.Millisecond),
598+
))
599+
assert.Len(t, messages, 0)
600+
q.Start()
601+
time.Sleep(50 * time.Millisecond)
602+
q.Release()
603+
assert.Len(t, messages, 0)
604+
assert.Equal(t, 2, count)
605+
}

0 commit comments

Comments
 (0)