From 921e14329ca721b1894e30c9a3ea45dbddd6f933 Mon Sep 17 00:00:00 2001 From: "davidzhang.zc" Date: Mon, 31 May 2021 11:51:55 +0800 Subject: [PATCH] [fix] consumer map race [fix] producer invalid task panic --- consumer/worker.go | 80 ++++++++++++++++++++++---------------- producer/io_thread_pool.go | 15 ++++--- producer/io_worker.go | 17 ++++---- 3 files changed, 64 insertions(+), 48 deletions(-) diff --git a/consumer/worker.go b/consumer/worker.go index bc149c94..2a1d8ccc 100644 --- a/consumer/worker.go +++ b/consumer/worker.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/aliyun/aliyun-log-go-sdk" + sls "github.com/aliyun/aliyun-log-go-sdk" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "go.uber.org/atomic" @@ -16,7 +16,7 @@ type ConsumerWorker struct { consumerHeatBeat *ConsumerHeatBeat client *ConsumerClient workerShutDownFlag *atomic.Bool - shardConsumer map[int]*ShardConsumerWorker + shardConsumer sync.Map // map[int]*ShardConsumerWorker do func(shard int, logGroup *sls.LogGroupList) string waitGroup sync.WaitGroup Logger log.Logger @@ -30,9 +30,9 @@ func InitConsumerWorker(option LogHubConfig, do func(int, *sls.LogGroupList) str consumerHeatBeat: consumerHeatBeat, client: consumerClient, workerShutDownFlag: atomic.NewBool(false), - shardConsumer: make(map[int]*ShardConsumerWorker), - do: do, - Logger: logger, + //shardConsumer: make(map[int]*ShardConsumerWorker), + do: do, + Logger: logger, } consumerClient.createConsumerGroup() return consumerWorker @@ -82,14 +82,20 @@ func (consumerWorker *ConsumerWorker) run() { func (consumerWorker *ConsumerWorker) shutDownAndWait() { for { time.Sleep(500 * time.Millisecond) - for shard, consumer := range consumerWorker.shardConsumer { - if !consumer.isShutDownComplete() { - consumer.consumerShutDown() - } else if consumer.isShutDownComplete() { - delete(consumerWorker.shardConsumer, shard) - } - } - if len(consumerWorker.shardConsumer) == 0 { + count := 0 + consumerWorker.shardConsumer.Range( + func(key, value interface{}) bool { + count++ + consumer := value.(*ShardConsumerWorker) + if !consumer.isShutDownComplete() { + consumer.consumerShutDown() + } else if consumer.isShutDownComplete() { + consumerWorker.shardConsumer.Delete(key) + } + return true + }, + ) + if count == 0 { break } } @@ -97,35 +103,41 @@ func (consumerWorker *ConsumerWorker) shutDownAndWait() { } func (consumerWorker *ConsumerWorker) getShardConsumer(shardId int) *ShardConsumerWorker { - consumer := consumerWorker.shardConsumer[shardId] - if consumer != nil { - return consumer + consumer, ok := consumerWorker.shardConsumer.Load(shardId) + if ok { + return consumer.(*ShardConsumerWorker) } - consumer = initShardConsumerWorker(shardId, consumerWorker.client, consumerWorker.do, consumerWorker.Logger) - consumerWorker.shardConsumer[shardId] = consumer - return consumer + consumerIns := initShardConsumerWorker(shardId, consumerWorker.client, consumerWorker.do, consumerWorker.Logger) + consumerWorker.shardConsumer.Store(shardId, consumerIns) + return consumerIns } func (consumerWorker *ConsumerWorker) cleanShardConsumer(owned_shards []int) { - for shard, consumer := range consumerWorker.shardConsumer { - if !Contain(shard, owned_shards) { - level.Info(consumerWorker.Logger).Log("msg", "try to call shut down for unassigned consumer shard", "shardId", shard) - consumer.consumerShutDown() - level.Info(consumerWorker.Logger).Log("msg", "Complete call shut down for unassigned consumer shard", "shardId", shard) - } + consumerWorker.shardConsumer.Range( + func(key, value interface{}) bool { + shard := key.(int) + consumer := value.(*ShardConsumerWorker) - if consumer.isShutDownComplete() { - isDeleteShard := consumerWorker.consumerHeatBeat.removeHeartShard(shard) - if isDeleteShard { - level.Info(consumerWorker.Logger).Log("msg", "Remove an assigned consumer shard", "shardId", shard) - delete(consumerWorker.shardConsumer, shard) - } else { - level.Info(consumerWorker.Logger).Log("msg", "Remove an assigned consumer shard failed", "shardId", shard) + if !Contain(shard, owned_shards) { + level.Info(consumerWorker.Logger).Log("msg", "try to call shut down for unassigned consumer shard", "shardId", shard) + consumer.consumerShutDown() + level.Info(consumerWorker.Logger).Log("msg", "Complete call shut down for unassigned consumer shard", "shardId", shard) } - } - } + + if consumer.isShutDownComplete() { + isDeleteShard := consumerWorker.consumerHeatBeat.removeHeartShard(shard) + if isDeleteShard { + level.Info(consumerWorker.Logger).Log("msg", "Remove an assigned consumer shard", "shardId", shard) + consumerWorker.shardConsumer.Delete(shard) + } else { + level.Info(consumerWorker.Logger).Log("msg", "Remove an assigned consumer shard failed", "shardId", shard) + } + } + return true + }, + ) } // This function is used to initialize the global log configuration diff --git a/producer/io_thread_pool.go b/producer/io_thread_pool.go index 24803145..3317f3db 100644 --- a/producer/io_thread_pool.go +++ b/producer/io_thread_pool.go @@ -36,6 +36,9 @@ func (threadPool *IoThreadPool) addTask(batch *ProducerBatch) { func (threadPool *IoThreadPool) popTask() *ProducerBatch { defer threadPool.lock.Unlock() threadPool.lock.Lock() + if threadPool.queue.Len() <= 0 { + return nil + } ele := threadPool.queue.Front() threadPool.queue.Remove(ele) return ele.Value.(*ProducerBatch) @@ -50,12 +53,12 @@ func (threadPool *IoThreadPool) hasTask() bool { func (threadPool *IoThreadPool) start(ioWorkerWaitGroup *sync.WaitGroup, ioThreadPoolwait *sync.WaitGroup) { defer ioThreadPoolwait.Done() for { - if threadPool.hasTask() { - select { - case threadPool.ioworker.maxIoWorker <- 1: - ioWorkerWaitGroup.Add(1) - go threadPool.ioworker.sendToServer(threadPool.popTask(), ioWorkerWaitGroup) - } + if task := threadPool.popTask(); task != nil { + threadPool.ioworker.startSendTask(ioWorkerWaitGroup) + go func(producerBatch *ProducerBatch) { + defer threadPool.ioworker.closeSendTask(ioWorkerWaitGroup) + threadPool.ioworker.sendToServer(producerBatch) + }(task) } else { if !threadPool.threadPoolShutDownFlag.Load() { time.Sleep(100 * time.Millisecond) diff --git a/producer/io_worker.go b/producer/io_worker.go index f2e33d4d..21f24d7b 100644 --- a/producer/io_worker.go +++ b/producer/io_worker.go @@ -41,14 +41,9 @@ func initIoWorker(client sls.ClientInterface, retryQueue *RetryQueue, logger log } } -func (ioWorker *IoWorker) sendToServer(producerBatch *ProducerBatch, ioWorkerWaitGroup *sync.WaitGroup) { - if producerBatch == nil || ioWorkerWaitGroup == nil { - return - } +func (ioWorker *IoWorker) sendToServer(producerBatch *ProducerBatch) { level.Debug(ioWorker.logger).Log("msg", "ioworker send data to server") - defer ioWorker.closeSendTask(ioWorkerWaitGroup) var err error - atomic.AddInt64(&ioWorker.taskCount, 1) if producerBatch.shardHash != nil { err = ioWorker.client.PostLogStoreLogs(producerBatch.getProject(), producerBatch.getLogstore(), producerBatch.logGroup, producerBatch.getShardHash()) } else { @@ -116,9 +111,15 @@ func (ioWorker *IoWorker) addErrorMessageToBatchAttempt(producerBatch *ProducerB } func (ioWorker *IoWorker) closeSendTask(ioWorkerWaitGroup *sync.WaitGroup) { - ioWorkerWaitGroup.Done() - atomic.AddInt64(&ioWorker.taskCount, -1) <-ioWorker.maxIoWorker + atomic.AddInt64(&ioWorker.taskCount, -1) + ioWorkerWaitGroup.Done() +} + +func (ioWorker *IoWorker) startSendTask(ioWorkerWaitGroup *sync.WaitGroup) { + atomic.AddInt64(&ioWorker.taskCount, 1) + ioWorker.maxIoWorker <- 1 + ioWorkerWaitGroup.Add(1) } func (ioWorker *IoWorker) excuteFailedCallback(producerBatch *ProducerBatch) {