diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index 84f1b1b7523..f5f864ead94 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -12,7 +12,7 @@ import ( ) var logger = log.NewFromLogrus(logrus.StandardLogger()) -var flowLogger = netflow.NewManager(context.Background()).GetLogger() +var flowLogger = netflow.NewManager(context.Background(), []byte{}).GetLogger() // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index d9576a1c02c..ad6d430b859 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -24,7 +24,7 @@ import ( ) var logger = log.NewFromLogrus(logrus.StandardLogger()) -var flowLogger = netflow.NewManager(context.Background()).GetLogger() +var flowLogger = netflow.NewManager(context.Background(), []byte{}).GetLogger() type IFaceMock struct { SetFilterFunc func(device.PacketFilter) error diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 04da0e7d5a2..bce85034740 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -15,7 +15,7 @@ import ( mgmProto "github.com/netbirdio/netbird/management/proto" ) -var flowLogger = netflow.NewManager(context.Background()).GetLogger() +var flowLogger = netflow.NewManager(context.Background(), []byte{}).GetLogger() func TestDefaultManager(t *testing.T) { networkMap := &mgmProto.NetworkMap{ diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 4741b9e1dea..853bc9b9c16 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -30,7 +30,7 @@ import ( "github.com/netbirdio/netbird/formatter" ) -var flowLogger = netflow.NewManager(context.Background()).GetLogger() +var flowLogger = netflow.NewManager(context.Background(), []byte{}).GetLogger() type mocWGIface struct { filter device.PacketFilter diff --git a/client/internal/engine.go b/client/internal/engine.go index 16eeff1c9d0..d1f87882077 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -216,6 +216,7 @@ func NewEngine( statusRecorder *peer.Status, checks []*mgmProto.Checks, ) *Engine { + publicKey := config.WgPrivateKey.PublicKey() engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, @@ -234,7 +235,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - flowManager: netflow.NewManager(clientCtx), + flowManager: netflow.NewManager(clientCtx, publicKey[:]), } if runtime.GOOS == "ios" { if !fileExists(mobileDep.StateFilePath) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 640f87c4939..1e23c1dceb6 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -107,6 +107,10 @@ func (l *Logger) GetEvents() []*types.Event { return l.Store.GetEvents() } +func (l *Logger) DeleteEvents(ids []string) { + l.Store.DeleteEvents(ids) +} + func (l *Logger) Close() { l.stop() l.cancel() diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index b3f2594eb6c..527dfd25690 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -3,38 +3,62 @@ package netflow import ( "context" "sync" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/client/internal/netflow/logger" "github.com/netbirdio/netbird/client/internal/netflow/types" + "github.com/netbirdio/netbird/flow/client" + "github.com/netbirdio/netbird/flow/proto" ) type Manager struct { - mux sync.Mutex - logger types.FlowLogger - flowConfig *types.FlowConfig + mux sync.Mutex + logger types.FlowLogger + flowConfig *types.FlowConfig + ctx context.Context + receiverClient *client.GRPCClient + publicKey []byte } -func NewManager(ctx context.Context) *Manager { +func NewManager(ctx context.Context, publicKey []byte) *Manager { return &Manager{ - logger: logger.New(ctx), + logger: logger.New(ctx), + ctx: ctx, + publicKey: publicKey, } } func (m *Manager) Update(update *types.FlowConfig) error { - m.mux.Lock() - defer m.mux.Unlock() if update == nil { return nil } - + m.mux.Lock() + defer m.mux.Unlock() + previous := m.flowConfig m.flowConfig = update if update.Enabled { m.logger.Enable() + if previous == nil || !previous.Enabled { + flowClient, err := client.NewClient(m.ctx, m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature) + if err != nil { + return err + } + log.Infof("flow client connected to %s", m.flowConfig.URL) + m.receiverClient = flowClient + go m.receiveACKs() + go m.startSender() + } return nil } m.logger.Disable() + if previous != nil && previous.Enabled { + return m.receiverClient.Close() + } return nil } @@ -46,3 +70,78 @@ func (m *Manager) Close() { func (m *Manager) GetLogger() types.FlowLogger { return m.logger } + +func (m *Manager) startSender() { + ticker := time.NewTicker(m.flowConfig.Interval) + defer ticker.Stop() + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + events := m.logger.GetEvents() + for _, event := range events { + log.Infof("send flow event to server: %s", event.ID) + err := m.send(event) + if err != nil { + log.Errorf("send flow event to server: %s", err) + } + } + } + } +} + +func (m *Manager) receiveACKs() { + if m.receiverClient == nil { + return + } + err := m.receiverClient.Receive(m.ctx, func(ack *proto.FlowEventAck) error { + log.Infof("receive flow event ack: %s", ack.EventId) + m.logger.DeleteEvents([]string{ack.EventId}) + return nil + }) + if err != nil { + log.Errorf("receive flow event ack: %s", err) + } +} + +func (m *Manager) send(event *types.Event) error { + if m.receiverClient == nil { + return nil + } + return m.receiverClient.Send(m.ctx, toProtoEvent(m.publicKey, event)) +} + +func toProtoEvent(publicKey []byte, event *types.Event) *proto.FlowEvent { + protoEvent := &proto.FlowEvent{ + EventId: event.ID, + FlowId: event.FlowID.String(), + Timestamp: timestamppb.New(event.Timestamp), + PublicKey: publicKey, + EventFields: &proto.EventFields{ + Type: proto.Type(event.Type), + Direction: proto.Direction(event.Direction), + Protocol: uint32(event.Protocol), + SourceIp: event.SourceIP.AsSlice(), + DestIp: event.DestIP.AsSlice(), + }, + } + if event.Protocol == 1 { + protoEvent.EventFields.ConnectionInfo = &proto.EventFields_IcmpInfo{ + IcmpInfo: &proto.ICMPInfo{ + IcmpType: uint32(event.ICMPType), + IcmpCode: uint32(event.ICMPCode), + }, + } + return protoEvent + } + + protoEvent.EventFields.ConnectionInfo = &proto.EventFields_PortInfo{ + PortInfo: &proto.PortInfo{ + SourcePort: uint32(event.SourcePort), + DestPort: uint32(event.DestPort), + }, + } + + return protoEvent +} diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index dd4b60889bd..02c8135feee 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -97,6 +97,8 @@ type FlowLogger interface { StoreEvent(flowEvent EventFields) // GetEvents returns all stored events GetEvents() []*Event + // DeleteEvents deletes events from the store + DeleteEvents([]string) // Close closes the logger Close() // Enable enables the flow logger receiver diff --git a/flow/client/auth.go b/flow/client/auth.go new file mode 100644 index 00000000000..de9e9cece2d --- /dev/null +++ b/flow/client/auth.go @@ -0,0 +1,32 @@ +package client + +import ( + "context" + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +var _ credentials.PerRPCCredentials = (*authToken)(nil) + +type authToken struct { + metaMap map[string]string +} + +func (t authToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { + return t.metaMap, nil +} + +func (authToken) RequireTransportSecurity() bool { + return false // Set to true if you want to require a secure connection +} + +// WithAuthToken returns a DialOption which sets the receiver flow credentials and places auth state on each outbound RPC +func withAuthToken(payload, signature string) grpc.DialOption { + value := fmt.Sprintf("%s.%s", signature, payload) + authMap := map[string]string{ + "authorization": "Bearer " + value, + } + return grpc.WithPerRPCCredentials(authToken{metaMap: authMap}) +} diff --git a/flow/client/client.go b/flow/client/client.go new file mode 100644 index 00000000000..47c80ef0d78 --- /dev/null +++ b/flow/client/client.go @@ -0,0 +1,158 @@ +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "strings" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" + + "github.com/netbirdio/netbird/flow/proto" + "github.com/netbirdio/netbird/util/embeddedroots" + nbgrpc "github.com/netbirdio/netbird/util/grpc" +) + +type GRPCClient struct { + realClient proto.FlowServiceClient + clientConn *grpc.ClientConn + stream proto.FlowService_EventsClient +} + +func NewClient(ctx context.Context, addr, payload, signature string) (*GRPCClient, error) { + + transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + if strings.Contains(addr, "443") { + + certPool, err := x509.SystemCertPool() + if err != nil || certPool == nil { + log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) + certPool = embeddedroots.Get() + } + + transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + RootCAs: certPool, + })) + } + + connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + conn, err := grpc.DialContext( + connCtx, + addr, + transportOption, + nbgrpc.WithCustomDialer(), + grpc.WithBlock(), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 30 * time.Second, + Timeout: 10 * time.Second, + }), + withAuthToken(payload, signature), + ) + + if err != nil { + return nil, fmt.Errorf("dialing with context: %s", err) + } + + client := &GRPCClient{ + realClient: proto.NewFlowServiceClient(conn), + clientConn: conn, + } + return client, nil +} + +func (c *GRPCClient) Close() error { + return c.clientConn.Close() +} + +func (c *GRPCClient) Receive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error { + backOff := defaultBackoff(ctx) + operation := func() error { + connState := c.clientConn.GetState() + if connState == connectivity.Shutdown { + return backoff.Permanent(fmt.Errorf("connection to signal has been shut down")) + } + + stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true)) + if err != nil { + return err + } + c.stream = stream + + err = checkHeader(stream) + if err != nil { + return err + } + + return c.receive(stream, msgHandler) + } + + err := backoff.Retry(operation, backOff) + if err != nil { + log.Errorf("exiting the flow receiver service connection retry loop due to the unrecoverable error: %v", err) + return err + } + + return nil +} + +func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error { + for { + msg, err := stream.Recv() + if err != nil { + return err + } + + if err := msgHandler(msg); err != nil { + return err + } + } +} + +func checkHeader(stream proto.FlowService_EventsClient) error { + header, err := stream.Header() + if err != nil { + log.Errorf("waiting for flow receiver header: %s", err) + return err + } + + if len(header) == 0 { + log.Error("flow receiver sent no headers") + return fmt.Errorf("should have headers") + } + return nil +} + +func defaultBackoff(ctx context.Context) backoff.BackOff { + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 800 * time.Millisecond, + RandomizationFactor: 1, + Multiplier: 1.7, + MaxInterval: 10 * time.Second, + MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +func (c *GRPCClient) Send(ctx context.Context, event *proto.FlowEvent) error { + if c.stream == nil { + return fmt.Errorf("stream not initialized") + } + + err := c.stream.Send(event) + if err != nil { + return fmt.Errorf("sending flow event: %s", err) + } + + return nil +}