-
-
Notifications
You must be signed in to change notification settings - Fork 620
/
Copy pathaccount_request_buffer.go
111 lines (93 loc) · 3.04 KB
/
account_request_buffer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package server
import (
"context"
"os"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// AccountRequest is a single pending lookup for an account. It carries the ID
// of the requested account and the channel on which the result is returned.
type AccountRequest struct {
// AccountID identifies the account to load; requests are grouped by this value.
AccountID string
// ResultChan receives exactly one *AccountResult and is closed right after
// that result has been sent.
ResultChan chan *AccountResult
}
// AccountResult holds the account data or an error.
// One AccountResult value is shared (by pointer) between all requests that
// were served by the same batch.
type AccountResult struct {
// Account is the account returned by the store lookup.
Account *types.Account
// Err is the error returned by the store lookup, if any.
Err error
}
// AccountRequestBuffer collects account lookups for a short interval and
// serves all requests for the same account ID with a single store call.
type AccountRequestBuffer struct {
// store is used to load accounts when a batch is processed.
store store.Store
// getAccountRequests holds the pending requests keyed by account ID.
// Access is guarded by mu.
getAccountRequests map[string][]*AccountRequest
// mu protects getAccountRequests.
mu sync.Mutex
// getAccountRequestCh carries new requests to the background processing loop.
getAccountRequestCh chan *AccountRequest
// bufferInterval is how long the first request for an account waits before
// the batch for that account is processed.
bufferInterval time.Duration
}
// NewAccountRequestBuffer builds an AccountRequestBuffer backed by the given
// store and starts its background request-processing loop, which runs until
// ctx is done.
//
// The buffer interval is read from the NB_GET_ACCOUNT_BUFFER_INTERVAL
// environment variable (time.ParseDuration syntax). When the variable is unset
// or cannot be parsed, the interval falls back to 100ms; a warning is logged
// only if a non-empty value failed to parse.
func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer {
	interval := 100 * time.Millisecond

	raw := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
	parsed, err := time.ParseDuration(raw)
	switch {
	case err == nil:
		interval = parsed
	case raw != "":
		// An empty value just means "not configured", so only warn on real input.
		log.WithContext(ctx).Warnf("failed to parse account request buffer interval: %s", err)
	}

	log.WithContext(ctx).Infof("set account request buffer interval to %s", interval)

	buffer := &AccountRequestBuffer{
		store:               store,
		getAccountRequests:  make(map[string][]*AccountRequest),
		getAccountRequestCh: make(chan *AccountRequest),
		bufferInterval:      interval,
	}

	go buffer.processGetAccountRequests(ctx)

	return buffer
}
// GetAccountWithBackpressure returns the account with the given ID. Requests
// for the same account that arrive within the buffer interval are answered by
// a single store lookup.
//
// The request is handed to the background processing loop and the call blocks
// until the batch result is delivered. If ctx is done first — either while the
// request is being submitted or while waiting for the result — the call
// returns a nil account and ctx.Err() instead of blocking indefinitely.
func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) {
	req := &AccountRequest{
		AccountID: accountID,
		// Capacity 1 lets the batch processor deliver the result without
		// blocking even if this caller has already stopped waiting.
		ResultChan: make(chan *AccountResult, 1),
	}

	log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
	startTime := time.Now()

	// The request channel is unbuffered: without the ctx case this send would
	// hang forever once the processing loop is no longer receiving.
	select {
	case ac.getAccountRequestCh <- req:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	select {
	case result := <-req.ResultChan:
		log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
		return result.Account, result.Err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// processGetAccountBatch removes every request currently queued for accountID,
// loads the account from the store once, and sends that single result to each
// removed request before closing its result channel. It does nothing when no
// requests are queued for the account.
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
	// Detach the queue under the lock so the store call below runs unlocked.
	pending := func() []*AccountRequest {
		ac.mu.Lock()
		defer ac.mu.Unlock()

		queued := ac.getAccountRequests[accountID]
		delete(ac.getAccountRequests, accountID)
		return queued
	}()

	if len(pending) == 0 {
		return
	}

	began := time.Now()
	account, err := ac.store.GetAccount(ctx, accountID)
	log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(began))

	// Every waiter receives the same result value.
	shared := &AccountResult{Account: account, Err: err}
	for i := range pending {
		out := pending[i].ResultChan
		out <- shared
		close(out)
	}
}
// processGetAccountRequests is the background loop that receives incoming
// requests and queues them per account ID. The first request queued for an
// account starts a goroutine that waits for the buffer interval and then
// processes the whole batch for that account. The loop returns when ctx is done.
func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) {
	// flushLater waits out the buffer interval, then serves the account's queue.
	flushLater := func(ctx context.Context, accountID string) {
		time.Sleep(ac.bufferInterval)
		ac.processGetAccountBatch(ctx, accountID)
	}

	for {
		select {
		case <-ctx.Done():
			return
		case incoming := <-ac.getAccountRequestCh:
			id := incoming.AccountID

			ac.mu.Lock()
			queue := append(ac.getAccountRequests[id], incoming)
			ac.getAccountRequests[id] = queue
			// Only the request that opens a new queue schedules the flush;
			// later arrivals simply join the pending batch.
			if len(queue) == 1 {
				go flushLater(ctx, id)
			}
			ac.mu.Unlock()
		}
	}
}