diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 679f288e32a..929e8a6565a 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -218,6 +218,14 @@ func (m *Manager) SetLogLevel(log.Level) { // not supported } +func (m *Manager) EnableRouting() error { + return nil +} + +func (m *Manager) DisableRouting() error { + return nil +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index de25ff1f11c..d007e20a51c 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -101,6 +101,10 @@ type Manager interface { Flush() error SetLogLevel(log.Level) + + EnableRouting() error + + DisableRouting() error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 4fe52bd5361..de68f329156 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -323,6 +323,14 @@ func (m *Manager) SetLogLevel(log.Level) { // not supported } +func (m *Manager) EnableRouting() error { + return nil +} + +func (m *Manager) DisableRouting() error { + return nil +} + // Flush rule/chain/set operations from the buffer // // Method also get all rules after flush and refreshes handle values in the rulesets diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 889e4cbb1a9..55e4a174e6a 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -74,6 +74,8 @@ type Manager struct { mutex sync.RWMutex + // indicates whether we server routes are disabled + disableServerRoutes bool // indicates whether we forward packets not destined for ourselves routingEnabled bool // indicates whether we leave forwarding and filtering to the native firewall @@ -149,15 +151,16 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe return d }, }, - nativeFirewall: nativeFirewall, - outgoingRules: make(map[string]RuleSet), - incomingRules: make(map[string]RuleSet), - wgIface: iface, - localipmanager: newLocalIPManager(), - routingEnabled: false, - stateful: !disableConntrack, - logger: nblog.NewFromLogrus(log.StandardLogger()), - netstack: netstack.IsEnabled(), + nativeFirewall: nativeFirewall, + outgoingRules: make(map[string]RuleSet), + incomingRules: make(map[string]RuleSet), + wgIface: iface, + localipmanager: newLocalIPManager(), + disableServerRoutes: disableServerRoutes, + routingEnabled: false, + stateful: !disableConntrack, + logger: nblog.NewFromLogrus(log.StandardLogger()), + netstack: netstack.IsEnabled(), // default true for non-netstack, for netstack only if explicitly enabled localForwarding: !netstack.IsEnabled() || enableLocalForwarding, } @@ -166,7 +169,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe return nil, fmt.Errorf("update local IPs: %w", err) } - // Only initialize trackers if stateful mode is enabled if disableConntrack { log.Info("conntrack is disabled") } else { @@ -175,7 +177,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) } - m.determineRouting(iface, disableServerRoutes) + // netstack needs the forwarder for local traffic + if m.netstack && m.localForwarding { + if err := m.initForwarder(iface); err != nil { + log.Errorf("failed to initialize forwarder: %v", err) + } + } if err := m.blockInvalidRouted(iface); err != nil { log.Errorf("failed to block invalid routed traffic: %v", err) @@ -213,9 +220,21 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { return nil } -func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) { - disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting)) - forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter)) +func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) error { + var disableUspRouting, forceUserspaceRouter bool + var err error + if val := os.Getenv(EnvDisableUserspaceRouting); val != "" { + disableUspRouting, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err) + } + } + if val := os.Getenv(EnvForceUserspaceRouter); val != "" { + forceUserspaceRouter, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err) + } + } switch { case disableUspRouting: @@ -252,32 +271,37 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes log.Info("userspace routing enabled by default") } - // netstack needs the forwarder for local traffic - if m.netstack && m.localForwarding || - m.routingEnabled && !m.nativeRouter { - - m.initForwarder(iface) + if m.routingEnabled && !m.nativeRouter { + return m.initForwarder(iface) } + + return nil } // initForwarder initializes the forwarder, it disables routing on errors -func (m *Manager) initForwarder(iface common.IFaceMapper) { +func (m *Manager) initForwarder(iface common.IFaceMapper) error { + if m.forwarder != nil { + return nil + } + // Only supported in userspace mode as we need to inject packets back into wireguard directly intf := iface.GetWGDevice() if intf == nil { - log.Info("forwarding not supported") m.routingEnabled = false - return + return errors.New("forwarding not supported") } forwarder, err := forwarder.New(iface, m.logger, m.netstack) if err != nil { - log.Errorf("failed to create forwarder: %v", err) m.routingEnabled = false - return + return fmt.Errorf("create forwarder: %w", err) } m.forwarder = forwarder + + log.Debug("forwarder initialized") + + return nil } func (m *Manager) Init(*statemanager.Manager) error { @@ -285,7 +309,7 @@ func (m *Manager) Init(*statemanager.Manager) error { } func (m *Manager) IsServerRouteSupported() bool { - return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil + return true } func (m *Manager) AddNatRule(pair firewall.RouterPair) error { @@ -953,3 +977,34 @@ func (m *Manager) SetLogLevel(level log.Level) { m.logger.SetLevel(nblog.Level(level)) } } + +func (m *Manager) EnableRouting() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.determineRouting(m.wgIface, m.disableServerRoutes) +} + +func (m *Manager) DisableRouting() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.forwarder == nil { + return nil + } + + m.routingEnabled = false + m.nativeRouter = false + + // don't stop forwarder if in use by netstack + if m.netstack && m.localForwarding { + return nil + } + + m.forwarder.Stop() + m.forwarder = nil + + log.Debug("Forwarder stopped") + + return nil +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 9c7f1f6faea..61b8dab35b9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -286,15 +286,25 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro m.updateClientNetworks(updateSerial, filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes) } + m.clientRoutes = newClientRoutesIDMap - if m.serverRouter != nil { - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return err + if m.serverRouter == nil { + return nil + } + + if len(newServerRoutesMap) > 0 { + if err := m.serverRouter.EnableRouting(); err != nil { + return fmt.Errorf("enable routing: %w", err) + } + } else { + if err := m.serverRouter.DisableRouting(); err != nil { + return fmt.Errorf("disable routing: %w", err) } } - m.clientRoutes = newClientRoutesIDMap + if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil { + return fmt.Errorf("update routes: %w", err) + } return nil } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index b60cb318e67..e676981dff2 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -167,6 +167,14 @@ func (m *serverRouter) cleanUp() { m.statusRecorder.UpdateLocalPeerState(state) } +func (r *serverRouter) EnableRouting() error { + return r.firewall.EnableRouting() +} + +func (r *serverRouter) DisableRouting() error { + return r.firewall.DisableRouting() +} + func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { // TODO: add ipv6 source := getDefaultPrefix(route.Network)