diff --git a/lib/auth/server_info.go b/lib/auth/server_info.go index ec958d82112ea..abee67875e071 100644 --- a/lib/auth/server_info.go +++ b/lib/auth/server_info.go @@ -20,57 +20,99 @@ package auth import ( "context" + "log/slog" "maps" "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/defaults" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/internalutils/stream" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" ) +const serverInfoBatchSize = 100 +const timeBetweenServerInfoBatches = 10 * time.Second +const timeBetweenServerInfoLoops = 10 * time.Minute + +// ServerInfoAccessPoint is the subset of the auth server interface needed to +// reconcile server info resources. +type ServerInfoAccessPoint interface { + // GetNodeStream returns a stream of nodes. + GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server] + // GetServerInfo returns a ServerInfo by name. + GetServerInfo(ctx context.Context, name string) (types.ServerInfo, error) + // UpdateLabels updates the labels on an instance over the inventory control + // stream. + UpdateLabels(ctx context.Context, req proto.InventoryUpdateLabelsRequest) error + // GetClock returns the server clock. + GetClock() clockwork.Clock +} + // ReconcileServerInfos periodically reconciles the labels of ServerInfo // resources with their corresponding Teleport SSH servers. -func (a *Server) ReconcileServerInfos(ctx context.Context) error { - const batchSize = 100 - const timeBetweenBatches = 10 * time.Second - const timeBetweenReconciliationLoops = 10 * time.Minute - clock := a.GetClock() +func ReconcileServerInfos(ctx context.Context, ap ServerInfoAccessPoint) error { + retry, err := retryutils.NewLinear(retryutils.LinearConfig{ + First: utils.FullJitter(defaults.MaxWatcherBackoff / 10), + Step: defaults.MaxWatcherBackoff / 5, + Max: defaults.MaxWatcherBackoff, + Jitter: retryutils.NewHalfJitter(), + Clock: ap.GetClock(), + }) + if err != nil { + return trace.Wrap(err) + } for { - var failedUpdates int - // Iterate over nodes in batches. - nodeStream := a.GetNodeStream(ctx, defaults.Namespace) - var nodes []types.Server - - for moreNodes := true; moreNodes; { - nodes, moreNodes = stream.Take(nodeStream, batchSize) - updates, err := a.setLabelsOnNodes(ctx, nodes) - if err != nil { - return trace.Wrap(err) - } - failedUpdates += updates - - select { - case <-clock.After(timeBetweenBatches): - case <-ctx.Done(): - return nil - } + err := retry.For(ctx, func() error { return trace.Wrap(reconcileServerInfos(ctx, ap)) }) + if err != nil { + return trace.Wrap(err) + } + retry.Reset() + select { + case <-ap.GetClock().After(timeBetweenServerInfoLoops): + case <-ctx.Done(): + return nil } + } +} + +func reconcileServerInfos(ctx context.Context, ap ServerInfoAccessPoint) error { + var failedUpdates int + // Iterate over nodes in batches. + nodeStream := ap.GetNodeStream(ctx, apidefaults.Namespace) - // Log number of nodes that we couldn't find a control stream for. 
- if failedUpdates > 0 { - log.Debugf("unable to update labels on %v node(s) due to missing control stream", failedUpdates) + for { + nodes, moreNodes := stream.Take(nodeStream, serverInfoBatchSize) + updates, err := setLabelsOnNodes(ctx, ap, nodes) + if err != nil { + return trace.Wrap(err) + } + failedUpdates += updates + if !moreNodes { + break } select { - case <-clock.After(timeBetweenReconciliationLoops): + case <-ap.GetClock().After(timeBetweenServerInfoBatches): case <-ctx.Done(): return nil } } + if err := nodeStream.Done(); err != nil { + return trace.Wrap(err) + } + + // Log number of nodes that we couldn't find a control stream for. + if failedUpdates > 0 { + slog.DebugContext(ctx, "unable to update labels on nodes due to missing control stream", "failed_updates", failedUpdates) + } + return nil } // getServerInfoNames gets the names of ServerInfos that could exist for a @@ -86,7 +128,7 @@ func getServerInfoNames(node types.Server) []string { return append(names, types.ServerInfoNameFromNodeName(node.GetName())) } -func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (failedUpdates int, err error) { +func setLabelsOnNodes(ctx context.Context, ap ServerInfoAccessPoint, nodes []types.Server) (failedUpdates int, err error) { for _, node := range nodes { // EICE Node labels can't be updated using the Inventory Control Stream because there's no reverse tunnel. // Labels are updated by the DiscoveryService during 'Server.handleEC2Instances'. @@ -99,7 +141,7 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (fa serverInfoNames := getServerInfoNames(node) serverInfos := make([]types.ServerInfo, 0, len(serverInfoNames)) for _, name := range serverInfoNames { - si, err := a.GetServerInfo(ctx, name) + si, err := ap.GetServerInfo(ctx, name) if err == nil { serverInfos = append(serverInfos, si) } else if !trace.IsNotFound(err) { @@ -111,7 +153,7 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (fa } // Didn't find control stream for node, save count for logging. - if err := a.updateLabelsOnNode(ctx, node, serverInfos); trace.IsNotFound(err) { + if err := updateLabelsOnNode(ctx, ap, node, serverInfos); trace.IsNotFound(err) { failedUpdates++ } else if err != nil { return failedUpdates, trace.Wrap(err) @@ -120,14 +162,14 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (fa return failedUpdates, nil } -func (a *Server) updateLabelsOnNode(ctx context.Context, node types.Server, serverInfos []types.ServerInfo) error { +func updateLabelsOnNode(ctx context.Context, ap ServerInfoAccessPoint, node types.Server, serverInfos []types.ServerInfo) error { // Merge labels from server infos. Later label sets should override earlier // ones if they conflict. 
newLabels := make(map[string]string) for _, si := range serverInfos { maps.Copy(newLabels, si.GetNewLabels()) } - err := a.UpdateLabels(ctx, proto.InventoryUpdateLabelsRequest{ + err := ap.UpdateLabels(ctx, proto.InventoryUpdateLabelsRequest{ ServerID: node.GetName(), Kind: proto.LabelUpdateKind_SSHServerCloudLabels, Labels: newLabels, diff --git a/lib/auth/server_info_test.go b/lib/auth/server_info_test.go index 11b19753dc53c..2ff7d0045c918 100644 --- a/lib/auth/server_info_test.go +++ b/lib/auth/server_info_test.go @@ -22,67 +22,77 @@ import ( "context" "testing" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" - "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/internalutils/stream" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" ) -type mockUpstream struct { - client.UpstreamInventoryControlStream - updatedLabels map[string]string +type mockServerInfoAccessPoint struct { + clock clockwork.FakeClock + nodes []types.Server + nodesErr error + serverInfos map[string]types.ServerInfo + serverInfoErr error + updatedLabels map[string]map[string]string } -func (m *mockUpstream) Send(_ context.Context, msg proto.DownstreamInventoryMessage) error { - if labelMsg, ok := msg.(proto.DownstreamInventoryUpdateLabels); ok { - m.updatedLabels = labelMsg.Labels +func newMockServerInfoAccessPoint() *mockServerInfoAccessPoint { + return &mockServerInfoAccessPoint{ + clock: clockwork.NewFakeClock(), + serverInfos: make(map[string]types.ServerInfo), + updatedLabels: make(map[string]map[string]string), } - return nil } -func (m *mockUpstream) Recv() <-chan proto.UpstreamInventoryMessage { - return make(chan proto.UpstreamInventoryMessage) +func (m *mockServerInfoAccessPoint) GetNodeStream(_ context.Context, _ string) stream.Stream[types.Server] { + if m.nodesErr != nil { + return stream.Fail[types.Server](m.nodesErr) + } + return stream.Slice(m.nodes) } -func (m *mockUpstream) Done() <-chan struct{} { - return make(chan struct{}) +func (m *mockServerInfoAccessPoint) GetServerInfo(_ context.Context, name string) (types.ServerInfo, error) { + if m.serverInfoErr != nil { + return nil, m.serverInfoErr + } + si, ok := m.serverInfos[name] + if !ok { + return nil, trace.NotFound("no server info named %q", name) + } + return si, nil } -func (m *mockUpstream) Close() error { +func (m *mockServerInfoAccessPoint) UpdateLabels(_ context.Context, req proto.InventoryUpdateLabelsRequest) error { + m.updatedLabels[req.ServerID] = req.Labels return nil } -// TestReconcileLabels verifies that an SSH server's labels can be updated by -// upserting a corresponding ServerInfo to the auth server. -func TestReconcileLabels(t *testing.T) { +func (m *mockServerInfoAccessPoint) GetClock() clockwork.Clock { + return m.clock +} + +func TestReconcileServerInfo(t *testing.T) { t.Parallel() const serverName = "test-server" - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - // Create auth server and fake inventory stream. 
- clock := clockwork.NewFakeClock() - pack, err := newTestPack(ctx, t.TempDir(), WithClock(clock)) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, pack.a.Close()) - require.NoError(t, pack.bk.Close()) + awsServerInfo, err := types.NewServerInfo(types.Metadata{ + Name: types.ServerInfoNameFromAWS("my-account", "my-instance"), + }, types.ServerInfoSpecV1{ + NewLabels: map[string]string{"a": "1", "b": "2"}, }) - upstream := &mockUpstream{} - t.Cleanup(func() { - require.NoError(t, upstream.Close()) + require.NoError(t, err) + regularServerInfo, err := types.NewServerInfo(types.Metadata{ + Name: types.ServerInfoNameFromNodeName(serverName), + }, types.ServerInfoSpecV1{ + NewLabels: map[string]string{"b": "3", "c": "4"}, }) - require.NoError(t, pack.a.RegisterInventoryControlStream(upstream, proto.UpstreamInventoryHello{ - Version: teleport.Version, - ServerID: serverName, - Services: []types.SystemRole{types.RoleNode}, - })) - - // Create server. + require.NoError(t, err) server, err := types.NewServer(serverName, types.KindNode, types.ServerSpecV2{ CloudMetadata: &types.CloudMetadata{ AWS: &types.AWSInfo{ @@ -92,29 +102,65 @@ func TestReconcileLabels(t *testing.T) { }, }) require.NoError(t, err) - _, err = pack.a.UpsertNode(ctx, server) - require.NoError(t, err) - // Update the server's labels. - awsServerInfo, err := types.NewServerInfo(types.Metadata{ - Name: types.ServerInfoNameFromAWS("my-account", "my-instance"), - }, types.ServerInfoSpecV1{ - NewLabels: map[string]string{"a": "1", "b": "2"}, + t.Run("ok", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + ap := newMockServerInfoAccessPoint() + ap.nodes = []types.Server{server} + ap.serverInfos = map[string]types.ServerInfo{ + awsServerInfo.GetName(): awsServerInfo, + regularServerInfo.GetName(): regularServerInfo, + } + + go ReconcileServerInfos(ctx, ap) + + // Wait until the reconciler finishes processing a batch. + ap.clock.BlockUntil(1) + // Check that the right labels were updated. + require.Equal(t, map[string]string{ + "aws/a": "1", + "aws/b": "2", + "dynamic/b": "3", + "dynamic/c": "4", + }, ap.updatedLabels[serverName]) }) - require.NoError(t, err) - require.NoError(t, pack.a.UpsertServerInfo(ctx, awsServerInfo)) - regularServerInfo, err := types.NewServerInfo(types.Metadata{ - Name: types.ServerInfoNameFromNodeName(serverName), - }, types.ServerInfoSpecV1{ - NewLabels: map[string]string{"b": "3", "c": "4"}, + t.Run("restart on error", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + ap := newMockServerInfoAccessPoint() + ap.nodes = []types.Server{server} + ap.nodesErr = trace.Errorf("an error") + ap.serverInfos = map[string]types.ServerInfo{ + awsServerInfo.GetName(): awsServerInfo, + regularServerInfo.GetName(): regularServerInfo, + } + + go ReconcileServerInfos(ctx, ap) + + // Block until we hit the retryer. + ap.clock.BlockUntil(1) + // Return the error at a different place and advance to the next batch. + ap.nodesErr = nil + ap.serverInfoErr = trace.Errorf("an error") + ap.clock.Advance(defaults.MaxWatcherBackoff) + // Block until we hit the retryer again. + ap.clock.BlockUntil(1) + // Clear the error and allow a successful run. + ap.serverInfoErr = nil + ap.clock.Advance(defaults.MaxWatcherBackoff) + // Block until we hit the loop waiter (meaning the server infos were + // successfully processed). + ap.clock.BlockUntil(1) + // Check that the right labels were updated. 
+ require.Equal(t, map[string]string{ + "aws/a": "1", + "aws/b": "2", + "dynamic/b": "3", + "dynamic/c": "4", + }, ap.updatedLabels[serverName]) }) - require.NoError(t, err) - require.NoError(t, pack.a.UpsertServerInfo(ctx, regularServerInfo)) - - go pack.a.ReconcileServerInfos(ctx) - // Wait until the reconciler finishes processing the serverinfo. - clock.BlockUntil(1) - // Check that labels were received downstream. - require.Equal(t, map[string]string{"aws/a": "1", "aws/b": "2", "dynamic/b": "3", "dynamic/c": "4"}, upstream.updatedLabels) } diff --git a/lib/service/service.go b/lib/service/service.go index 3eb2f722a7c60..d38d5d18cc681 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2421,7 +2421,7 @@ func (process *TeleportProcess) initAuthService() error { }) process.RegisterFunc("auth.server_info", func() error { - return trace.Wrap(authServer.ReconcileServerInfos(process.GracefulExitContext())) + return trace.Wrap(auth.ReconcileServerInfos(process.GracefulExitContext(), authServer)) }) // execute this when process is asked to exit: process.OnExit("auth.shutdown", func(payload any) {
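
Usage sketch (illustrative only, not part of the patch above): because the reconciler is now decoupled from *auth.Server behind the ServerInfoAccessPoint interface, any implementation of that interface can drive it. The package name, type names, and in-memory behavior below are assumptions for the example; they mirror the test mock rather than any real production adapter.

// Hypothetical example package; not part of the change.
package example

import (
	"context"

	"github.com/gravitational/trace"
	"github.com/jonboulle/clockwork"

	"github.com/gravitational/teleport/api/client/proto"
	"github.com/gravitational/teleport/api/internalutils/stream"
	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/auth"
)

// fakeAccessPoint serves nodes and ServerInfos from memory and records the
// label updates the reconciler would otherwise send over the inventory
// control stream. It satisfies auth.ServerInfoAccessPoint.
type fakeAccessPoint struct {
	clock       clockwork.Clock
	nodes       []types.Server
	serverInfos map[string]types.ServerInfo
	updated     map[string]map[string]string
}

func (f *fakeAccessPoint) GetNodeStream(_ context.Context, _ string) stream.Stream[types.Server] {
	return stream.Slice(f.nodes)
}

func (f *fakeAccessPoint) GetServerInfo(_ context.Context, name string) (types.ServerInfo, error) {
	si, ok := f.serverInfos[name]
	if !ok {
		return nil, trace.NotFound("no server info named %q", name)
	}
	return si, nil
}

func (f *fakeAccessPoint) UpdateLabels(_ context.Context, req proto.InventoryUpdateLabelsRequest) error {
	f.updated[req.ServerID] = req.Labels
	return nil
}

func (f *fakeAccessPoint) GetClock() clockwork.Clock { return f.clock }

// runReconciler blocks until ctx is canceled or the retry loop gives up, so
// callers typically run it in its own goroutine, as service.go does above.
func runReconciler(ctx context.Context, ap *fakeAccessPoint) error {
	return trace.Wrap(auth.ReconcileServerInfos(ctx, ap))
}

In production the access point is simply the auth server itself, which is why the service.go hunk passes authServer straight to auth.ReconcileServerInfos; swapping in a clockwork.FakeClock (as the updated test does) lets callers step deterministically through the 10-second batch waits and 10-minute loop waits.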