[v15] Fix server info reconciler not restarting after an error (#47453)
* Add retry logic to server info reconciler

This change makes the server info reconciler restart after an error instead of exiting.

* Fix test compilation
atburke authored Oct 11, 2024
1 parent f131e47 commit 48e9346
Showing 3 changed files with 179 additions and 91 deletions.
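For orientation, the commit message above describes wrapping each reconciliation pass in retry logic so that a single failure no longer kills the loop. The dependency-free sketch below shows the shape of that pattern; the runWithRestart helper, its constants, and the plain linear backoff are illustrative stand-ins for the retryutils.NewLinear/retry.For machinery with jitter and an injected clock that the actual change in lib/auth/server_info.go uses.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// runWithRestart keeps calling reconcile until it succeeds, backing off
// linearly between failed attempts, then waits loopInterval and repeats.
// Only context cancellation stops the outer loop.
func runWithRestart(ctx context.Context, reconcile func(context.Context) error, step, maxBackoff, loopInterval time.Duration) {
	for {
		backoff := step
		for {
			if err := reconcile(ctx); err == nil {
				break
			} else {
				fmt.Printf("reconcile failed, retrying in %v: %v\n", backoff, err)
			}
			select {
			case <-time.After(backoff):
			case <-ctx.Done():
				return
			}
			if backoff += step; backoff > maxBackoff {
				backoff = maxBackoff
			}
		}
		select {
		case <-time.After(loopInterval):
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	attempts := 0
	runWithRestart(ctx, func(context.Context) error {
		attempts++
		if attempts < 3 {
			return errors.New("transient failure") // fail twice, then succeed
		}
		fmt.Println("reconciled successfully")
		return nil
	}, 200*time.Millisecond, time.Second, 500*time.Millisecond)
}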
108 changes: 75 additions & 33 deletions lib/auth/server_info.go
@@ -20,57 +20,99 @@ package auth

import (
"context"
"log/slog"
"maps"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/defaults"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/internalutils/stream"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/utils"
)

const serverInfoBatchSize = 100
const timeBetweenServerInfoBatches = 10 * time.Second
const timeBetweenServerInfoLoops = 10 * time.Minute

// ServerInfoAccessPoint is the subset of the auth server interface needed to
// reconcile server info resources.
type ServerInfoAccessPoint interface {
// GetNodeStream returns a stream of nodes.
GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server]
// GetServerInfo returns a ServerInfo by name.
GetServerInfo(ctx context.Context, name string) (types.ServerInfo, error)
// UpdateLabels updates the labels on an instance over the inventory control
// stream.
UpdateLabels(ctx context.Context, req proto.InventoryUpdateLabelsRequest) error
// GetClock returns the server clock.
GetClock() clockwork.Clock
}

// ReconcileServerInfos periodically reconciles the labels of ServerInfo
// resources with their corresponding Teleport SSH servers.
func (a *Server) ReconcileServerInfos(ctx context.Context) error {
const batchSize = 100
const timeBetweenBatches = 10 * time.Second
const timeBetweenReconciliationLoops = 10 * time.Minute
clock := a.GetClock()
func ReconcileServerInfos(ctx context.Context, ap ServerInfoAccessPoint) error {
retry, err := retryutils.NewLinear(retryutils.LinearConfig{
First: utils.FullJitter(defaults.MaxWatcherBackoff / 10),
Step: defaults.MaxWatcherBackoff / 5,
Max: defaults.MaxWatcherBackoff,
Jitter: retryutils.NewHalfJitter(),
Clock: ap.GetClock(),
})
if err != nil {
return trace.Wrap(err)
}

for {
var failedUpdates int
// Iterate over nodes in batches.
nodeStream := a.GetNodeStream(ctx, defaults.Namespace)
var nodes []types.Server

for moreNodes := true; moreNodes; {
nodes, moreNodes = stream.Take(nodeStream, batchSize)
updates, err := a.setLabelsOnNodes(ctx, nodes)
if err != nil {
return trace.Wrap(err)
}
failedUpdates += updates

select {
case <-clock.After(timeBetweenBatches):
case <-ctx.Done():
return nil
}
err := retry.For(ctx, func() error { return trace.Wrap(reconcileServerInfos(ctx, ap)) })
if err != nil {
return trace.Wrap(err)
}
retry.Reset()
select {
case <-ap.GetClock().After(timeBetweenServerInfoLoops):
case <-ctx.Done():
return nil
}
}
}

func reconcileServerInfos(ctx context.Context, ap ServerInfoAccessPoint) error {
var failedUpdates int
// Iterate over nodes in batches.
nodeStream := ap.GetNodeStream(ctx, apidefaults.Namespace)

// Log number of nodes that we couldn't find a control stream for.
if failedUpdates > 0 {
log.Debugf("unable to update labels on %v node(s) due to missing control stream", failedUpdates)
for {
nodes, moreNodes := stream.Take(nodeStream, serverInfoBatchSize)
updates, err := setLabelsOnNodes(ctx, ap, nodes)
if err != nil {
return trace.Wrap(err)
}
failedUpdates += updates
if !moreNodes {
break
}

select {
case <-clock.After(timeBetweenReconciliationLoops):
case <-ap.GetClock().After(timeBetweenServerInfoBatches):
case <-ctx.Done():
return nil
}
}
if err := nodeStream.Done(); err != nil {
return trace.Wrap(err)
}

// Log number of nodes that we couldn't find a control stream for.
if failedUpdates > 0 {
slog.DebugContext(ctx, "unable to update labels on nodes due to missing control stream", "failed_updates", failedUpdates)
}
return nil
}

// getServerInfoNames gets the names of ServerInfos that could exist for a
@@ -86,7 +128,7 @@ func getServerInfoNames(node types.Server) []string {
return append(names, types.ServerInfoNameFromNodeName(node.GetName()))
}

func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (failedUpdates int, err error) {
func setLabelsOnNodes(ctx context.Context, ap ServerInfoAccessPoint, nodes []types.Server) (failedUpdates int, err error) {
for _, node := range nodes {
// EICE Node labels can't be updated using the Inventory Control Stream because there's no reverse tunnel.
// Labels are updated by the DiscoveryService during 'Server.handleEC2Instances'.
@@ -99,7 +141,7 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (failedUpdates int, err error) {
serverInfoNames := getServerInfoNames(node)
serverInfos := make([]types.ServerInfo, 0, len(serverInfoNames))
for _, name := range serverInfoNames {
si, err := a.GetServerInfo(ctx, name)
si, err := ap.GetServerInfo(ctx, name)
if err == nil {
serverInfos = append(serverInfos, si)
} else if !trace.IsNotFound(err) {
@@ -111,7 +153,7 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (failedUpdates int, err error) {
}

// Didn't find control stream for node, save count for logging.
if err := a.updateLabelsOnNode(ctx, node, serverInfos); trace.IsNotFound(err) {
if err := updateLabelsOnNode(ctx, ap, node, serverInfos); trace.IsNotFound(err) {
failedUpdates++
} else if err != nil {
return failedUpdates, trace.Wrap(err)
Expand All @@ -120,14 +162,14 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (fa
return failedUpdates, nil
}

func (a *Server) updateLabelsOnNode(ctx context.Context, node types.Server, serverInfos []types.ServerInfo) error {
func updateLabelsOnNode(ctx context.Context, ap ServerInfoAccessPoint, node types.Server, serverInfos []types.ServerInfo) error {
// Merge labels from server infos. Later label sets should override earlier
// ones if they conflict.
newLabels := make(map[string]string)
for _, si := range serverInfos {
maps.Copy(newLabels, si.GetNewLabels())
}
err := a.UpdateLabels(ctx, proto.InventoryUpdateLabelsRequest{
err := ap.UpdateLabels(ctx, proto.InventoryUpdateLabelsRequest{
ServerID: node.GetName(),
Kind: proto.LabelUpdateKind_SSHServerCloudLabels,
Labels: newLabels,
160 changes: 103 additions & 57 deletions lib/auth/server_info_test.go
@@ -22,67 +22,77 @@ import (
"context"
"testing"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/internalutils/stream"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
)

type mockUpstream struct {
client.UpstreamInventoryControlStream
updatedLabels map[string]string
type mockServerInfoAccessPoint struct {
clock clockwork.FakeClock
nodes []types.Server
nodesErr error
serverInfos map[string]types.ServerInfo
serverInfoErr error
updatedLabels map[string]map[string]string
}

func (m *mockUpstream) Send(_ context.Context, msg proto.DownstreamInventoryMessage) error {
if labelMsg, ok := msg.(proto.DownstreamInventoryUpdateLabels); ok {
m.updatedLabels = labelMsg.Labels
func newMockServerInfoAccessPoint() *mockServerInfoAccessPoint {
return &mockServerInfoAccessPoint{
clock: clockwork.NewFakeClock(),
serverInfos: make(map[string]types.ServerInfo),
updatedLabels: make(map[string]map[string]string),
}
return nil
}

func (m *mockUpstream) Recv() <-chan proto.UpstreamInventoryMessage {
return make(chan proto.UpstreamInventoryMessage)
func (m *mockServerInfoAccessPoint) GetNodeStream(_ context.Context, _ string) stream.Stream[types.Server] {
if m.nodesErr != nil {
return stream.Fail[types.Server](m.nodesErr)
}
return stream.Slice(m.nodes)
}

func (m *mockUpstream) Done() <-chan struct{} {
return make(chan struct{})
func (m *mockServerInfoAccessPoint) GetServerInfo(_ context.Context, name string) (types.ServerInfo, error) {
if m.serverInfoErr != nil {
return nil, m.serverInfoErr
}
si, ok := m.serverInfos[name]
if !ok {
return nil, trace.NotFound("no server info named %q", name)
}
return si, nil
}

func (m *mockUpstream) Close() error {
func (m *mockServerInfoAccessPoint) UpdateLabels(_ context.Context, req proto.InventoryUpdateLabelsRequest) error {
m.updatedLabels[req.ServerID] = req.Labels
return nil
}

// TestReconcileLabels verifies that an SSH server's labels can be updated by
// upserting a corresponding ServerInfo to the auth server.
func TestReconcileLabels(t *testing.T) {
func (m *mockServerInfoAccessPoint) GetClock() clockwork.Clock {
return m.clock
}

func TestReconcileServerInfo(t *testing.T) {
t.Parallel()

const serverName = "test-server"
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

// Create auth server and fake inventory stream.
clock := clockwork.NewFakeClock()
pack, err := newTestPack(ctx, t.TempDir(), WithClock(clock))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, pack.a.Close())
require.NoError(t, pack.bk.Close())
awsServerInfo, err := types.NewServerInfo(types.Metadata{
Name: types.ServerInfoNameFromAWS("my-account", "my-instance"),
}, types.ServerInfoSpecV1{
NewLabels: map[string]string{"a": "1", "b": "2"},
})
upstream := &mockUpstream{}
t.Cleanup(func() {
require.NoError(t, upstream.Close())
require.NoError(t, err)
regularServerInfo, err := types.NewServerInfo(types.Metadata{
Name: types.ServerInfoNameFromNodeName(serverName),
}, types.ServerInfoSpecV1{
NewLabels: map[string]string{"b": "3", "c": "4"},
})
require.NoError(t, pack.a.RegisterInventoryControlStream(upstream, proto.UpstreamInventoryHello{
Version: teleport.Version,
ServerID: serverName,
Services: []types.SystemRole{types.RoleNode},
}))

// Create server.
require.NoError(t, err)
server, err := types.NewServer(serverName, types.KindNode, types.ServerSpecV2{
CloudMetadata: &types.CloudMetadata{
AWS: &types.AWSInfo{
@@ -92,29 +102,65 @@ func TestReconcileLabels(t *testing.T) {
},
})
require.NoError(t, err)
_, err = pack.a.UpsertNode(ctx, server)
require.NoError(t, err)

// Update the server's labels.
awsServerInfo, err := types.NewServerInfo(types.Metadata{
Name: types.ServerInfoNameFromAWS("my-account", "my-instance"),
}, types.ServerInfoSpecV1{
NewLabels: map[string]string{"a": "1", "b": "2"},
t.Run("ok", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

ap := newMockServerInfoAccessPoint()
ap.nodes = []types.Server{server}
ap.serverInfos = map[string]types.ServerInfo{
awsServerInfo.GetName(): awsServerInfo,
regularServerInfo.GetName(): regularServerInfo,
}

go ReconcileServerInfos(ctx, ap)

// Wait until the reconciler finishes processing a batch.
ap.clock.BlockUntil(1)
// Check that the right labels were updated.
require.Equal(t, map[string]string{
"aws/a": "1",
"aws/b": "2",
"dynamic/b": "3",
"dynamic/c": "4",
}, ap.updatedLabels[serverName])
})
require.NoError(t, err)
require.NoError(t, pack.a.UpsertServerInfo(ctx, awsServerInfo))

regularServerInfo, err := types.NewServerInfo(types.Metadata{
Name: types.ServerInfoNameFromNodeName(serverName),
}, types.ServerInfoSpecV1{
NewLabels: map[string]string{"b": "3", "c": "4"},
t.Run("restart on error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

ap := newMockServerInfoAccessPoint()
ap.nodes = []types.Server{server}
ap.nodesErr = trace.Errorf("an error")
ap.serverInfos = map[string]types.ServerInfo{
awsServerInfo.GetName(): awsServerInfo,
regularServerInfo.GetName(): regularServerInfo,
}

go ReconcileServerInfos(ctx, ap)

// Block until we hit the retryer.
ap.clock.BlockUntil(1)
// Return the error at a different place and advance to the next batch.
ap.nodesErr = nil
ap.serverInfoErr = trace.Errorf("an error")
ap.clock.Advance(defaults.MaxWatcherBackoff)
// Block until we hit the retryer again.
ap.clock.BlockUntil(1)
// Clear the error and allow a successful run.
ap.serverInfoErr = nil
ap.clock.Advance(defaults.MaxWatcherBackoff)
// Block until we hit the loop waiter (meaning the server infos were
// successfully processed).
ap.clock.BlockUntil(1)
// Check that the right labels were updated.
require.Equal(t, map[string]string{
"aws/a": "1",
"aws/b": "2",
"dynamic/b": "3",
"dynamic/c": "4",
}, ap.updatedLabels[serverName])
})
require.NoError(t, err)
require.NoError(t, pack.a.UpsertServerInfo(ctx, regularServerInfo))

go pack.a.ReconcileServerInfos(ctx)
// Wait until the reconciler finishes processing the serverinfo.
clock.BlockUntil(1)
// Check that labels were received downstream.
require.Equal(t, map[string]string{"aws/a": "1", "aws/b": "2", "dynamic/b": "3", "dynamic/c": "4"}, upstream.updatedLabels)
}
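The rewritten test drives the reconciler purely through clockwork's fake clock: BlockUntil(n) parks the test until n goroutines are waiting on the clock, and Advance fires their timers, so errors can be injected and cleared between retry attempts without real sleeps. A stripped-down sketch of that synchronization pattern follows; the pollLoop helper and the durations are made up for illustration.

package example

import (
	"context"
	"testing"
	"time"

	"github.com/jonboulle/clockwork"
)

// pollLoop stands in for a reconciler loop: each time the clock fires it
// sends a signal, until the context is canceled.
func pollLoop(ctx context.Context, clock clockwork.Clock, ran chan<- struct{}) {
	for {
		select {
		case <-clock.After(time.Minute):
			ran <- struct{}{}
		case <-ctx.Done():
			return
		}
	}
}

func TestFakeClockDrivesLoop(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	clock := clockwork.NewFakeClock()
	ran := make(chan struct{}, 1)
	go pollLoop(ctx, clock, ran)

	// BlockUntil(1) waits until the goroutine is parked on clock.After,
	// so the Advance below is guaranteed to wake it rather than race past it.
	clock.BlockUntil(1)
	clock.Advance(time.Minute)

	select {
	case <-ran:
	case <-time.After(time.Second):
		t.Fatal("loop did not run after advancing the fake clock")
	}
}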
2 changes: 1 addition & 1 deletion lib/service/service.go
@@ -2421,7 +2421,7 @@ func (process *TeleportProcess) initAuthService() error {
})

process.RegisterFunc("auth.server_info", func() error {
return trace.Wrap(authServer.ReconcileServerInfos(process.GracefulExitContext()))
return trace.Wrap(auth.ReconcileServerInfos(process.GracefulExitContext(), authServer))
})
// execute this when process is asked to exit:
process.OnExit("auth.shutdown", func(payload any) {