From bb20704c3451fd2e9669ececabcda159fbce69fd Mon Sep 17 00:00:00 2001 From: Spencer Reeves Date: Sat, 14 May 2022 14:05:40 -0700 Subject: [PATCH] test for thread.pool --- thread/pool.go | 19 ++++++++++++------- thread/pool_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/thread/pool.go b/thread/pool.go index 0c5f01c..9d0cfa6 100644 --- a/thread/pool.go +++ b/thread/pool.go @@ -10,6 +10,7 @@ type Pool[K any] struct { Consume func(elem *K) error OnError func(elem *K, e error) Closed bool + closeCh chan struct{} waitGroup sync.WaitGroup workers []*Thread } @@ -21,31 +22,35 @@ func NewPool[K any](count int, consumerChan chan K, consumerFn func(elem *K) err Consume: consumerFn, OnError: onError, Closed: false, + closeCh: make(chan struct{}), } - p.waitGroup.Add(count) for i := 0; i < count; i++ { - p.workers = append(p.workers, Consumer[K](&p.waitGroup, consumerChan, consumerFn, onError)) + p.workers = append(p.workers, Consumer[K](&p.waitGroup, p.closeCh, consumerChan, consumerFn, onError)) } return &p } -// Close Has side effects! Closes the channel and optionally waits for threads to indicate they have closed. func (p *Pool[K]) Close(block bool) error { - if !p.Closed { - close(p.ConsumerChannel) + if p.Closed { + return nil + } else { + for i := 0; i < p.Count; i++ { + p.closeCh <- struct{}{} + } } - p.Closed = true if block { p.waitGroup.Wait() + close(p.closeCh) } + p.Closed = true return nil } -func (p *Pool[K]) Metrics(aggregated bool) []*Metric { +func (p *Pool[K]) Metrics() []*Metric { var metrics []*Metric for _, w := range p.workers { metrics = append(metrics, w.Metrics) diff --git a/thread/pool_test.go b/thread/pool_test.go index e69de29..3de67a5 100644 --- a/thread/pool_test.go +++ b/thread/pool_test.go @@ -0,0 +1,44 @@ +package thread_test + +import ( + "github.com/spencerreeves/snippets/thread" + "testing" + "time" +) + +func TestNewPool(t *testing.T) { + cnt, errs, ch := 0, 0, make(chan int, 1) + + // Verify we can create a pool + p := thread.NewPool[int](2, ch, incFn(&cnt), errCntFn(&errs)) + + // Verify pool consumes data + ch <- 1 + ch <- 1 + time.Sleep(time.Millisecond) + if cnt != 2 { + t.Error("invalid count") + } + + // Verify all threads are closed when we close a pool + if err := p.Close(true); err != nil || !p.Closed { + t.Error("pool failed to close") + } + + ch <- 1 + if cnt != 2 { + t.Error("pool failed to close, consuming channel") + t.Fail() + } + + // Verify the metrics + processedCount, processedErrs := 0, 0 + for _, m := range p.Metrics() { + processedCount += m.ProcessedCount + processedErrs += m.ErrorCount + } + if processedCount != cnt || processedErrs != errs { + t.Error("consume and error callbacks failed") + t.Fail() + } +}