diff --git a/daemon/firewall/nftables/monitor.go b/daemon/firewall/nftables/monitor.go index b8a2754b0a..88f3d62e00 100644 --- a/daemon/firewall/nftables/monitor.go +++ b/daemon/firewall/nftables/monitor.go @@ -40,7 +40,7 @@ func (n *Nft) AreRulesLoaded() bool { } } } - // we expect to have exactly 3 rules (2 queue and dns). If there're less or more, then we + // we expect to have exactly 3 rules (2 queue and 1 dns). If there're less or more, then we // need to reload them. if nRules != 3 { log.Warning("nfables filter rules not loaded: %d", nRules) diff --git a/daemon/firewall/nftables/rules.go b/daemon/firewall/nftables/rules.go index 9640acf890..966db872c2 100644 --- a/daemon/firewall/nftables/rules.go +++ b/daemon/firewall/nftables/rules.go @@ -5,6 +5,7 @@ import ( "github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs" "github.com/evilsocket/opensnitch/daemon/log" + daemonNetlink "github.com/evilsocket/opensnitch/daemon/netlink" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -180,8 +181,14 @@ func (n *Nft) QueueConnections(enable bool, logError bool) (error, error) { // flush conntrack as soon as netfilter rule is set. This ensures that already-established // connections will go to netfilter queue. if err := netlink.ConntrackTableFlush(netlink.ConntrackTable); err != nil { - log.Error("nftables, error in ConntrackTableFlush %s", err) + log.Error("nftables, error flushing ConntrackTable %s", err) } + if err := netlink.ConntrackTableFlush(netlink.ConntrackExpectTable); err != nil { + log.Error("nftables, error flusing ConntrackExpectTable %s", err) + } + + // Force established connections to reestablish again. + daemonNetlink.KillAllSockets() } return nil, nil diff --git a/daemon/netlink/socket.go b/daemon/netlink/socket.go index 9492c1e3e4..68c27c162f 100644 --- a/daemon/netlink/socket.go +++ b/daemon/netlink/socket.go @@ -7,6 +7,7 @@ import ( "syscall" "github.com/evilsocket/opensnitch/daemon/log" + "golang.org/x/sys/unix" ) // GetSocketInfo asks the kernel via netlink for a given connection. @@ -151,13 +152,58 @@ func KillSocket(proto string, srcIP net.IP, srcPort uint, dstIP net.IP, dstPort if sockList, err := SocketGet(family, ipproto, uint16(srcPort), uint16(dstPort), srcIP, dstIP); err == nil { for _, s := range sockList { - if err := socketKill(family, ipproto, s.ID); err != nil { + if err := SocketKill(family, ipproto, s.ID); err != nil { log.Debug("Unable to kill socket: %d, %d, %v", srcPort, dstPort, err) } } } } +// KillSockets kills all sockets given a family and a protocol. +// Be careful if you don't exclude local sockets, many local servers may misbehave, +// entering in an infinite loop. +func KillSockets(fam, proto uint8, excludeLocal bool) error { + sockListTCP, err := SocketsDump(fam, proto) + if err != nil { + return fmt.Errorf("eBPF could not dump TCP (%d/%d) sockets via netlink: %v", fam, proto, err) + } + + for _, sock := range sockListTCP { + if excludeLocal && (sock.ID.Destination.IsPrivate() || + sock.ID.Source.IsUnspecified() || + sock.ID.Destination.IsUnspecified()) { + continue + } + log.Error("KILLINGIT: %+v", sock.ID) + if err := SocketKill(fam, proto, sock.ID); err != nil { + log.Error("ERRORERRORERROR KILLING: %s", err) + } + } + + return nil +} + +// KillAllSockets kills the sockets for the given families and protocols. +func KillAllSockets() { + type opts struct { + fam uint8 + proto uint8 + } + optList := []opts{ + // add families and protos as wish + {unix.AF_INET, uint8(syscall.IPPROTO_TCP)}, + {unix.AF_INET6, uint8(syscall.IPPROTO_TCP)}, + {unix.AF_INET, uint8(syscall.IPPROTO_UDP)}, + {unix.AF_INET6, uint8(syscall.IPPROTO_UDP)}, + {unix.AF_INET, uint8(syscall.IPPROTO_SCTP)}, + {unix.AF_INET6, uint8(syscall.IPPROTO_SCTP)}, + } + for _, opt := range optList { + KillSockets(opt.fam, opt.proto, true) + } + +} + // SocketsAreEqual compares 2 different sockets to see if they match. func SocketsAreEqual(aSocket, bSocket *Socket) bool { return ((*aSocket).INode == (*bSocket).INode && diff --git a/daemon/netlink/socket_linux.go b/daemon/netlink/socket_linux.go index 944f278c21..caf1d53360 100644 --- a/daemon/netlink/socket_linux.go +++ b/daemon/netlink/socket_linux.go @@ -188,7 +188,7 @@ func (s *Socket) deserialize(b []byte) error { } // SocketKill kills a connection -func socketKill(family, proto uint8, sockID SocketID) error { +func SocketKill(family, proto uint8, sockID SocketID) error { sockReq := &SocketRequest{ Family: family,