Skip to content

Commit

Permalink
test for thread.pool
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerreeves committed May 14, 2022
1 parent 13618f2 commit bb20704
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 7 deletions.
19 changes: 12 additions & 7 deletions thread/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type Pool[K any] struct {
Consume func(elem *K) error
OnError func(elem *K, e error)
Closed bool
closeCh chan struct{}
waitGroup sync.WaitGroup
workers []*Thread
}
Expand All @@ -21,31 +22,35 @@ func NewPool[K any](count int, consumerChan chan K, consumerFn func(elem *K) err
Consume: consumerFn,
OnError: onError,
Closed: false,
closeCh: make(chan struct{}),
}

p.waitGroup.Add(count)
for i := 0; i < count; i++ {
p.workers = append(p.workers, Consumer[K](&p.waitGroup, consumerChan, consumerFn, onError))
p.workers = append(p.workers, Consumer[K](&p.waitGroup, p.closeCh, consumerChan, consumerFn, onError))
}

return &p
}

// Close Has side effects! Closes the channel and optionally waits for threads to indicate they have closed.
func (p *Pool[K]) Close(block bool) error {
if !p.Closed {
close(p.ConsumerChannel)
if p.Closed {
return nil
} else {
for i := 0; i < p.Count; i++ {
p.closeCh <- struct{}{}
}
}

p.Closed = true
if block {
p.waitGroup.Wait()
close(p.closeCh)
}

p.Closed = true
return nil
}

func (p *Pool[K]) Metrics(aggregated bool) []*Metric {
func (p *Pool[K]) Metrics() []*Metric {
var metrics []*Metric
for _, w := range p.workers {
metrics = append(metrics, w.Metrics)
Expand Down
44 changes: 44 additions & 0 deletions thread/pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package thread_test

import (
"github.com/spencerreeves/snippets/thread"
"testing"
"time"
)

func TestNewPool(t *testing.T) {
cnt, errs, ch := 0, 0, make(chan int, 1)

// Verify we can create a pool
p := thread.NewPool[int](2, ch, incFn(&cnt), errCntFn(&errs))

// Verify pool consumes data
ch <- 1
ch <- 1
time.Sleep(time.Millisecond)
if cnt != 2 {
t.Error("invalid count")
}

// Verify all threads are closed when we close a pool
if err := p.Close(true); err != nil || !p.Closed {
t.Error("pool failed to close")
}

ch <- 1
if cnt != 2 {
t.Error("pool failed to close, consuming channel")
t.Fail()
}

// Verify the metrics
processedCount, processedErrs := 0, 0
for _, m := range p.Metrics() {
processedCount += m.ProcessedCount
processedErrs += m.ErrorCount
}
if processedCount != cnt || processedErrs != errs {
t.Error("consume and error callbacks failed")
t.Fail()
}
}

0 comments on commit bb20704

Please # to comment.