Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client, relay] Fix/wg watch #3261

Merged
merged 14 commits into from
Feb 10, 2025
49 changes: 17 additions & 32 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ const (
defaultWgKeepAlive = 25 * time.Second

connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 1
connPriorityICEP2P ConnPriority = 2
connPriorityICETurn ConnPriority = 2
connPriorityICEP2P ConnPriority = 3
)

type WgConfig struct {
Expand Down Expand Up @@ -66,14 +66,6 @@ type ConnConfig struct {
ICEConfig icemaker.Config
}

type WorkerCallbacks struct {
OnRelayReadyCallback func(info RelayConnInfo)
OnRelayStatusChanged func(ConnStatus)

OnICEConnReadyCallback func(ConnPriority, ICEConnInfo)
OnICEStatusChanged func(ConnStatus)
}

type Conn struct {
log *log.Entry
mu sync.Mutex
Expand Down Expand Up @@ -135,21 +127,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
semaphore: semaphore,
}

rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected,
}

wFns := WorkerICECallbacks{
OnConnReady: conn.iCEConnectionIsReady,
OnStatusChanged: conn.onWorkerICEStateDisconnected,
}

ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)

relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -304,7 +286,7 @@ func (conn *Conn) GetKey() string {
}

// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock()
defer conn.mu.Unlock()

Expand Down Expand Up @@ -376,15 +358,15 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
}

// todo review to make sense to handle connecting and disconnected status also?
func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
func (conn *Conn) onICEStateDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()

if conn.ctx.Err() != nil {
return
}

conn.log.Tracef("ICE connection state changed to %s", newState)
conn.log.Tracef("ICE connection state changed to disconnected")

if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
Expand All @@ -404,10 +386,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.currentConnPriority = connPriorityRelay
}

changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)

conn.guard.SetICEConnDisconnected(changed)
changed := conn.statusICE.Get() != StatusDisconnected
if changed {
conn.guard.SetICEConnDisconnected()
}
conn.statusICE.Set(StatusDisconnected)

peerState := State{
PubKey: conn.config.Key,
Expand All @@ -422,7 +405,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
}
}

func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.mu.Lock()
defer conn.mu.Unlock()

Expand Down Expand Up @@ -474,7 +457,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
}

func (conn *Conn) onWorkerRelayStateDisconnected() {
func (conn *Conn) onRelayDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()

Expand All @@ -497,8 +480,10 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
}

changed := conn.statusRelay.Get() != StatusDisconnected
if changed {
conn.guard.SetRelayedConnDisconnected()
}
conn.statusRelay.Set(StatusDisconnected)
conn.guard.SetRelayedConnDisconnected(changed)

peerState := State{
PubKey: conn.config.Key,
Expand Down
36 changes: 12 additions & 24 deletions client/internal/peer/guard/guard.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ type Guard struct {
isConnectedOnAllWay isConnectedFunc
timeout time.Duration
srWatcher *SRWatcher
relayedConnDisconnected chan bool
iCEConnDisconnected chan bool
relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{}
}

func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
Expand All @@ -41,8 +41,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
isConnectedOnAllWay: isConnectedFn,
timeout: timeout,
srWatcher: srWatcher,
relayedConnDisconnected: make(chan bool, 1),
iCEConnDisconnected: make(chan bool, 1),
relayedConnDisconnected: make(chan struct{}, 1),
iCEConnDisconnected: make(chan struct{}, 1),
}
}

Expand All @@ -54,16 +54,16 @@ func (g *Guard) Start(ctx context.Context) {
}
}

func (g *Guard) SetRelayedConnDisconnected(changed bool) {
func (g *Guard) SetRelayedConnDisconnected() {
select {
case g.relayedConnDisconnected <- changed:
case g.relayedConnDisconnected <- struct{}{}:
default:
}
}

func (g *Guard) SetICEConnDisconnected(changed bool) {
func (g *Guard) SetICEConnDisconnected() {
select {
case g.iCEConnDisconnected <- changed:
case g.iCEConnDisconnected <- struct{}{}:
default:
}
}
Expand Down Expand Up @@ -96,19 +96,13 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
g.triggerOfferSending()
}

case changed := <-g.relayedConnDisconnected:
if !changed {
continue
}
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C

case changed := <-g.iCEConnDisconnected:
if !changed {
continue
}
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
Expand Down Expand Up @@ -138,16 +132,10 @@ func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
g.log.Infof("start listen for reconnect events...")
for {
select {
case changed := <-g.relayedConnDisconnected:
if !changed {
continue
}
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, triggering reconnect")
g.triggerOfferSending()
case changed := <-g.iCEConnDisconnected:
if !changed {
continue
}
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE state changed, try to send new offer")
g.triggerOfferSending()
case <-srReconnectedChan:
Expand Down
154 changes: 154 additions & 0 deletions client/internal/peer/wg_watcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package peer

import (
"context"
"sync"
"time"

log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/iface/configurer"
)

const (
// wgHandshakePeriod is the base interval that, together with
// wgHandshakeOvertime, forms checkPeriod.
// NOTE(review): presumably sized to WireGuard's handshake renewal interval — confirm.
wgHandshakePeriod = 3 * time.Minute
)

// NOTE(review): these are declared as var rather than const, presumably so
// tests can shorten them — confirm.
var (
// wgHandshakeOvertime is the slack added on top of wgHandshakePeriod. It is
// also the delay before the watcher performs its first handshake check.
wgHandshakeOvertime = 30 * time.Second // allowed delay in network
// checkPeriod is the maximum age a handshake may have before the peer is
// treated as disconnected; the next check is scheduled at
// lastHandshake + checkPeriod.
checkPeriod = wgHandshakePeriod + wgHandshakeOvertime
)

// WGInterfaceStater is the minimal view of the WireGuard interface that the
// watcher needs: the ability to read the statistics of a single peer.
type WGInterfaceStater interface {
// GetStats returns the WireGuard statistics for the peer identified by key.
// The watcher calls it with the peer key and reads only the LastHandshake
// field of the result.
GetStats(key string) (configurer.WGStats, error)
}

// WGWatcher periodically checks the last WireGuard handshake time of a single
// peer and invokes a callback when the handshake stops being refreshed.
type WGWatcher struct {
log *log.Entry
// wgIfaceStater is the source of the peer statistics.
wgIfaceStater WGInterfaceStater
// peerKey is the key passed to wgIfaceStater.GetStats.
peerKey string

// ctx and ctxCancel belong to the currently (or most recently) started
// watcher goroutine. ctxCancel is reset to nil by DisableWgWatcher.
ctx context.Context
ctxCancel context.CancelFunc
// ctxLock is held by EnableWgWatcher and DisableWgWatcher while they read
// or modify ctx and ctxCancel.
ctxLock sync.Mutex
// waitGroup tracks the watcher goroutine so DisableWgWatcher can wait for
// it to exit.
waitGroup sync.WaitGroup
}

// NewWGWatcher builds a WGWatcher that will read the stats of the peer
// identified by peerKey from wgIfaceStater. The returned watcher is idle: no
// goroutine is started until EnableWgWatcher is called.
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher {
	watcher := new(WGWatcher)
	watcher.log = log
	watcher.wgIfaceStater = wgIfaceStater
	watcher.peerKey = peerKey
	return watcher
}

// EnableWgWatcher launches the background handshake check for the peer. The
// watcher runs on a context derived from parentCtx and calls onDisconnectedFn
// when a handshake check fails.
//
// If a previously started watcher is still active (its context has not been
// cancelled), an error is logged and the call does nothing. If the initial
// handshake time cannot be read, a warning is logged and the watcher starts
// with the zero time as its baseline.
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
	w.log.Debugf("enable WireGuard watcher")

	w.ctxLock.Lock()
	defer w.ctxLock.Unlock()

	// A stored, not-yet-cancelled context means a watcher goroutine is still in use.
	alreadyRunning := w.ctx != nil && w.ctx.Err() == nil
	if alreadyRunning {
		w.log.Errorf("WireGuard watcher already enabled")
		return
	}

	w.ctx, w.ctxCancel = context.WithCancel(parentCtx)

	baseline, err := w.wgState()
	if err != nil {
		w.log.Warnf("failed to read initial wg stats: %v", err)
	}

	// Register the goroutine before starting it so DisableWgWatcher can wait on it.
	w.waitGroup.Add(1)
	go w.periodicHandshakeCheck(w.ctx, w.ctxCancel, onDisconnectedFn, baseline)
}

// DisableWgWatcher cancels the watcher's context and blocks until the watcher
// goroutine has exited. It is a no-op when no cancel function is stored, i.e.
// the watcher was never enabled or has already been disabled.
func (w *WGWatcher) DisableWgWatcher() {
	w.ctxLock.Lock()
	defer w.ctxLock.Unlock()

	cancel := w.ctxCancel
	if cancel == nil {
		return
	}

	w.log.Debugf("disable WireGuard watcher")

	cancel()
	w.ctxCancel = nil
	// The lock stays held while waiting for the goroutine to finish.
	w.waitGroup.Wait()
}

// periodicHandshakeCheck is the body of the watcher goroutine. The first check
// runs after wgHandshakeOvertime; every following check is scheduled for the
// moment the most recent handshake becomes checkPeriod old.
//
// When a check fails, onDisconnectedFn is called once and the goroutine exits.
// The goroutine also exits when ctx is cancelled. On exit it cancels its own
// context, stops the timer and marks itself done on the wait group.
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
	w.log.Infof("WireGuard watcher started")
	defer w.waitGroup.Done()

	timer := time.NewTimer(wgHandshakeOvertime)
	defer timer.Stop()
	defer ctxCancel()

	prev := initialHandshake

	for {
		select {
		case <-ctx.Done():
			w.log.Infof("WireGuard watcher stopped")
			return

		case <-timer.C:
			latest, alive := w.handshakeCheck(prev)
			if !alive {
				onDisconnectedFn()
				return
			}
			prev = *latest

			// Wake up again when the handshake just observed reaches checkPeriod of age.
			wait := time.Until(prev.Add(checkPeriod))
			timer.Reset(wait)

			w.log.Debugf("WireGuard watcher reset timer: %v", wait)
		}
	}
}

// handshakeCheck reads the peer's current handshake time and compares it with
// lastHandshake. It returns a pointer to the new handshake time and true when
// the handshake has changed, is not older than checkPeriod and is not in the
// future. In every other case, including a failure to read the stats, it logs
// the reason and returns (nil, false).
func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
	current, err := w.wgState()
	if err != nil {
		w.log.Errorf("failed to read wg stats: %v", err)
		return nil, false
	}

	w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, current)

	switch {
	case current.Equal(lastHandshake):
		// The known handshake did not change since the previous check.
		w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", current)
		return nil, false

	case current.Add(checkPeriod).Before(time.Now()):
		// If the machine was suspended, the handshake time lies too far in the past.
		w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", current)
		return nil, false

	case current.After(time.Now()):
		// A handshake time in the future is treated as an error.
		w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", current)
		return nil, false
	}

	return &current, true
}

// wgState fetches the stats of the watched peer and returns its last handshake
// time. On failure it returns the zero time together with the error.
func (w *WGWatcher) wgState() (time.Time, error) {
	stats, err := w.wgIfaceStater.GetStats(w.peerKey)
	if err == nil {
		return stats.LastHandshake, nil
	}
	return time.Time{}, err
}
Loading
Loading