Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Sticky context exclude node #198

Open
wants to merge 12 commits into
base: dev-v19
Choose a base branch
from
68 changes: 61 additions & 7 deletions liteclient/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewConnectionPoolWithAuth(key ed25519.PrivateKey) *ConnectionPool {
//
// In case if sticky node goes down, default balancer will be used as fallback
func (c *ConnectionPool) StickyContext(ctx context.Context) context.Context {
if ctx.Value(_StickyCtxKey) != nil {
if c.StickyNodeID(ctx) != 0 {
return ctx
}

Expand All @@ -94,11 +94,11 @@ func (c *ConnectionPool) StickyContext(ctx context.Context) context.Context {
}
c.nodesMx.RUnlock()

return context.WithValue(ctx, _StickyCtxKey, id)
return stickyContextWithNodeID(ctx, id)
}

func (c *ConnectionPool) StickyContextNextNode(ctx context.Context) (context.Context, error) {
nodeID, _ := ctx.Value(_StickyCtxKey).(uint32)
nodeID := c.StickyNodeID(ctx)
usedNodes, _ := ctx.Value(_StickyCtxUsedNodesKey).([]uint32)
if nodeID > 0 {
usedNodes = append(usedNodes, nodeID)
Expand All @@ -115,13 +115,51 @@ iter:
}
}

return context.WithValue(context.WithValue(ctx, _StickyCtxKey, node.id), _StickyCtxUsedNodesKey, usedNodes), nil
return context.WithValue(stickyContextWithNodeID(ctx, node.id), _StickyCtxUsedNodesKey, usedNodes), nil
}

return ctx, fmt.Errorf("no more active nodes left")
}

func (c *ConnectionPool) StickyContextExcludeNode(ctx context.Context) (context.Context, error) {
nodeID := c.StickyNodeID(ctx)
if nodeID == 0 {
return ctx, fmt.Errorf("no node to exclude")
}

usedNodes, _ := ctx.Value(_StickyCtxUsedNodesKey).([]uint32)
usedNodes = append(usedNodes, nodeID)

c.nodesMx.RLock()
defer c.nodesMx.RUnlock()

if len(c.activeNodes) < len(usedNodes) {
return context.WithValue(stickyContextWithNodeID(ctx, 0), _StickyCtxUsedNodesKey, usedNodes), nil
}

return ctx, fmt.Errorf("no more active nodes left")
}

func (c *ConnectionPool) StickyContextWithNodeID(ctx context.Context, nodeId uint32) context.Context {
usedNodes, _ := ctx.Value(_StickyCtxUsedNodesKey).([]uint32)
if len(usedNodes) == 0 {
return context.WithValue(ctx, _StickyCtxKey, nodeId)
}

nodes := make([]uint32, 0, len(usedNodes))
for _, node := range usedNodes {
if node != nodeId {
nodes = append(nodes, node)
}
}
if len(nodes) == len(usedNodes) {
return stickyContextWithNodeID(ctx, nodeId)
}

return context.WithValue(stickyContextWithNodeID(ctx, nodeId), _StickyCtxUsedNodesKey, usedNodes)
}

func stickyContextWithNodeID(ctx context.Context, nodeId uint32) context.Context {
return context.WithValue(ctx, _StickyCtxKey, nodeId)
}

Expand Down Expand Up @@ -185,13 +223,14 @@ func (c *ConnectionPool) QueryADNL(ctx context.Context, request tl.Serializable,
tm := time.Now()

var node *connection
if nodeID, ok := ctx.Value(_StickyCtxKey).(uint32); ok && nodeID > 0 {
excludeNodes, _ := ctx.Value(_StickyCtxUsedNodesKey).([]uint32)
if nodeID := c.StickyNodeID(ctx); nodeID > 0 {
node, err = c.querySticky(nodeID, req)
if err != nil {
return err
}
} else {
node, err = c.queryWithSmartBalancer(req)
node, err = c.queryWithSmartBalancer(req, excludeNodes...)
if err != nil {
return err
}
Expand Down Expand Up @@ -238,11 +277,23 @@ func (c *ConnectionPool) querySticky(id uint32, req *ADNLRequest) (*connection,
return c.queryWithSmartBalancer(req)
}

func (c *ConnectionPool) queryWithSmartBalancer(req *ADNLRequest) (*connection, error) {
func (c *ConnectionPool) queryWithSmartBalancer(req *ADNLRequest, excludeNodes ...uint32) (*connection, error) {
var reqNode *connection

c.nodesMx.RLock()

if len(c.activeNodes) == 0 {
c.nodesMx.RUnlock()
return nil, ErrNoActiveConnections
}

iter:
for _, node := range c.activeNodes {
for _, excludeNode := range excludeNodes {
if node.id == excludeNode {
continue iter
}
}
if reqNode == nil {
reqNode = node
continue
Expand All @@ -256,6 +307,9 @@ func (c *ConnectionPool) queryWithSmartBalancer(req *ADNLRequest) (*connection,
c.nodesMx.RUnlock()

if reqNode == nil {
if len(excludeNodes) > 0 {
return c.queryWithSmartBalancer(req)
}
return nil, ErrNoActiveConnections
}

Expand Down
1 change: 1 addition & 0 deletions ton/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type LiteClient interface {
QueryLiteserver(ctx context.Context, payload tl.Serializable, result tl.Serializable) error
StickyContext(ctx context.Context) context.Context
StickyContextNextNode(ctx context.Context) (context.Context, error)
StickyContextExcludeNode(ctx context.Context) (context.Context, error)
StickyNodeID(ctx context.Context) uint32
}

Expand Down
6 changes: 5 additions & 1 deletion ton/retrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (w *retryClient) QueryLiteserver(ctx context.Context, payload tl.Serializab
tries++

if err != nil {
if !errors.Is(err, liteclient.ErrADNLReqTimeout) && !errors.Is(err, context.DeadlineExceeded){
if !errors.Is(err, liteclient.ErrADNLReqTimeout) && !errors.Is(err, context.DeadlineExceeded) {
return err
}

Expand Down Expand Up @@ -69,6 +69,10 @@ func (w *retryClient) StickyNodeID(ctx context.Context) uint32 {
return w.original.StickyNodeID(ctx)
}

func (w *retryClient) StickyContextExcludeNode(ctx context.Context) (context.Context, error) {
return w.original.StickyContextExcludeNode(ctx)
}

func (w *retryClient) StickyContextNextNode(ctx context.Context) (context.Context, error) {
return w.original.StickyContextNextNode(ctx)
}
4 changes: 4 additions & 0 deletions ton/timeouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ func (c *timeoutClient) StickyNodeID(ctx context.Context) uint32 {
return c.original.StickyNodeID(ctx)
}

func (w *timeoutClient) StickyContextExcludeNode(ctx context.Context) (context.Context, error) {
return w.original.StickyContextExcludeNode(ctx)
}

func (c *timeoutClient) StickyContextNextNode(ctx context.Context) (context.Context, error) {
return c.original.StickyContextNextNode(ctx)
}
4 changes: 4 additions & 0 deletions ton/waiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,7 @@ func (w *waiterClient) StickyNodeID(ctx context.Context) uint32 {
func (w *waiterClient) StickyContextNextNode(ctx context.Context) (context.Context, error) {
return w.original.StickyContextNextNode(ctx)
}

func (w *waiterClient) StickyContextExcludeNode(ctx context.Context) (context.Context, error) {
return w.original.StickyContextExcludeNode(ctx)
}