[v15] Fix server info reconciler not restarting after an error #47453

Merged
merged 3 commits on Oct 11, 2024
Changes from 1 commit
108 changes: 75 additions & 33 deletions lib/auth/server_info.go
@@ -20,57 +20,99 @@ package auth

import (
"context"
"log/slog"
"maps"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/defaults"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/internalutils/stream"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/utils"
)

const serverInfoBatchSize = 100
const timeBetweenServerInfoBatches = 10 * time.Second
const timeBetweenServerInfoLoops = 10 * time.Minute

// ServerInfoAccessPoint is the subset of the auth server interface needed to
// reconcile server info resources.
type ServerInfoAccessPoint interface {
// GetNodeStream returns a stream of nodes.
GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server]
// GetServerInfo returns a ServerInfo by name.
GetServerInfo(ctx context.Context, name string) (types.ServerInfo, error)
// UpdateLabels updates the labels on an instance over the inventory control
// stream.
UpdateLabels(ctx context.Context, req proto.InventoryUpdateLabelsRequest) error
// GetClock returns the server clock.
GetClock() clockwork.Clock
}

// ReconcileServerInfos periodically reconciles the labels of ServerInfo
// resources with their corresponding Teleport SSH servers.
func (a *Server) ReconcileServerInfos(ctx context.Context) error {
const batchSize = 100
const timeBetweenBatches = 10 * time.Second
const timeBetweenReconciliationLoops = 10 * time.Minute
clock := a.GetClock()
func ReconcileServerInfos(ctx context.Context, ap ServerInfoAccessPoint) error {
retry, err := retryutils.NewLinear(retryutils.LinearConfig{
First: utils.FullJitter(defaults.MaxWatcherBackoff / 10),
Step: defaults.MaxWatcherBackoff / 5,
Max: defaults.MaxWatcherBackoff,
Jitter: retryutils.NewHalfJitter(),
Clock: ap.GetClock(),
})
if err != nil {
return trace.Wrap(err)
}

for {
var failedUpdates int
// Iterate over nodes in batches.
nodeStream := a.GetNodeStream(ctx, defaults.Namespace)
var nodes []types.Server

for moreNodes := true; moreNodes; {
nodes, moreNodes = stream.Take(nodeStream, batchSize)
updates, err := a.setLabelsOnNodes(ctx, nodes)
if err != nil {
return trace.Wrap(err)
}
failedUpdates += updates

select {
case <-clock.After(timeBetweenBatches):
case <-ctx.Done():
return nil
}
err := retry.For(ctx, func() error { return trace.Wrap(reconcileServerInfos(ctx, ap)) })
if err != nil {
return trace.Wrap(err)
}
retry.Reset()
select {
case <-ap.GetClock().After(timeBetweenServerInfoLoops):
case <-ctx.Done():
return nil
}
}
}

func reconcileServerInfos(ctx context.Context, ap ServerInfoAccessPoint) error {
var failedUpdates int
// Iterate over nodes in batches.
nodeStream := ap.GetNodeStream(ctx, apidefaults.Namespace)

// Log number of nodes that we couldn't find a control stream for.
if failedUpdates > 0 {
log.Debugf("unable to update labels on %v node(s) due to missing control stream", failedUpdates)
for {
nodes, moreNodes := stream.Take(nodeStream, serverInfoBatchSize)
updates, err := setLabelsOnNodes(ctx, ap, nodes)
if err != nil {
return trace.Wrap(err)
}
failedUpdates += updates
if !moreNodes {
break
}

select {
case <-clock.After(timeBetweenReconciliationLoops):
case <-ap.GetClock().After(timeBetweenServerInfoBatches):
case <-ctx.Done():
return nil
}
}
if err := nodeStream.Done(); err != nil {
return trace.Wrap(err)
}

// Log number of nodes that we couldn't find a control stream for.
if failedUpdates > 0 {
slog.DebugContext(ctx, "unable to update labels on nodes due to missing control stream", "failed_updates", failedUpdates)
}
return nil
}

// getServerInfoNames gets the names of ServerInfos that could exist for a
@@ -86,7 +128,7 @@ func getServerInfoNames(node types.Server) []string {
return append(names, types.ServerInfoNameFromNodeName(node.GetName()))
}

func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (failedUpdates int, err error) {
func setLabelsOnNodes(ctx context.Context, ap ServerInfoAccessPoint, nodes []types.Server) (failedUpdates int, err error) {
for _, node := range nodes {
// EICE Node labels can't be updated using the Inventory Control Stream because there's no reverse tunnel.
// Labels are updated by the DiscoveryService during 'Server.handleEC2Instances'.
@@ -99,7 +141,7 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (fa
serverInfoNames := getServerInfoNames(node)
serverInfos := make([]types.ServerInfo, 0, len(serverInfoNames))
for _, name := range serverInfoNames {
si, err := a.GetServerInfo(ctx, name)
si, err := ap.GetServerInfo(ctx, name)
if err == nil {
serverInfos = append(serverInfos, si)
} else if !trace.IsNotFound(err) {
@@ -111,7 +153,7 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (fa
}

// Didn't find control stream for node, save count for logging.
if err := a.updateLabelsOnNode(ctx, node, serverInfos); trace.IsNotFound(err) {
if err := updateLabelsOnNode(ctx, ap, node, serverInfos); trace.IsNotFound(err) {
failedUpdates++
} else if err != nil {
return failedUpdates, trace.Wrap(err)
@@ -120,14 +162,14 @@ func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (fa
return failedUpdates, nil
}

func (a *Server) updateLabelsOnNode(ctx context.Context, node types.Server, serverInfos []types.ServerInfo) error {
func updateLabelsOnNode(ctx context.Context, ap ServerInfoAccessPoint, node types.Server, serverInfos []types.ServerInfo) error {
// Merge labels from server infos. Later label sets should override earlier
// ones if they conflict.
newLabels := make(map[string]string)
for _, si := range serverInfos {
maps.Copy(newLabels, si.GetNewLabels())
}
err := a.UpdateLabels(ctx, proto.InventoryUpdateLabelsRequest{
err := ap.UpdateLabels(ctx, proto.InventoryUpdateLabelsRequest{
ServerID: node.GetName(),
Kind: proto.LabelUpdateKind_SSHServerCloudLabels,
Labels: newLabels,
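
The following sketch is not part of this change. It is a minimal, standalone illustration of the restart-on-error pattern that the new ReconcileServerInfos follows: run one reconciliation pass under a backoff retry, reset the backoff after a successful pass, then sleep until the next loop. The package, function, and constant names are placeholders, and a simple doubling backoff stands in for Teleport's retryutils.Linear helper.

package reconcileexample

import (
	"context"
	"log"
	"time"
)

const (
	initialBackoff   = 5 * time.Second
	maxBackoff       = 90 * time.Second
	timeBetweenLoops = 10 * time.Minute
)

// reconcileOnce stands in for a single reconciliation pass over all nodes; a
// real implementation would stream nodes in batches and push label updates.
func reconcileOnce(ctx context.Context) error {
	return nil
}

// RunReconciler keeps the reconciler alive for the lifetime of ctx: a failed
// pass is retried with backoff instead of terminating the loop, and a
// successful pass resets the backoff before sleeping until the next loop.
func RunReconciler(ctx context.Context) error {
	backoff := initialBackoff
	for {
		if err := reconcileOnce(ctx); err != nil {
			log.Printf("server info reconciliation failed, retrying in %v: %v", backoff, err)
			select {
			case <-time.After(backoff):
			case <-ctx.Done():
				return ctx.Err()
			}
			if backoff *= 2; backoff > maxBackoff {
				backoff = maxBackoff
			}
			continue
		}
		backoff = initialBackoff // reset the backoff after a successful pass
		select {
		case <-time.After(timeBetweenLoops):
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
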
171 changes: 114 additions & 57 deletions lib/auth/server_info_test.go
@@ -22,67 +22,78 @@ import (
"context"
"testing"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/internalutils/stream"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/utils"
)

type mockUpstream struct {
client.UpstreamInventoryControlStream
updatedLabels map[string]string
type mockServerInfoAccessPoint struct {
clock clockwork.FakeClock
nodes []types.Server
nodesErr error
serverInfos map[string]types.ServerInfo
serverInfoErr error
updatedLabels map[string]map[string]string
}

func (m *mockUpstream) Send(_ context.Context, msg proto.DownstreamInventoryMessage) error {
if labelMsg, ok := msg.(proto.DownstreamInventoryUpdateLabels); ok {
m.updatedLabels = labelMsg.Labels
func newMockServerInfoAccessPoint() *mockServerInfoAccessPoint {
return &mockServerInfoAccessPoint{
clock: clockwork.NewFakeClock(),
serverInfos: make(map[string]types.ServerInfo),
updatedLabels: make(map[string]map[string]string),
}
return nil
}

func (m *mockUpstream) Recv() <-chan proto.UpstreamInventoryMessage {
return make(chan proto.UpstreamInventoryMessage)
func (m *mockServerInfoAccessPoint) GetNodeStream(_ context.Context, _ string) stream.Stream[types.Server] {
if m.nodesErr != nil {
return stream.Fail[types.Server](m.nodesErr)
}
return stream.Slice(m.nodes)
}

func (m *mockUpstream) Done() <-chan struct{} {
return make(chan struct{})
func (m *mockServerInfoAccessPoint) GetServerInfo(_ context.Context, name string) (types.ServerInfo, error) {
if m.serverInfoErr != nil {
return nil, m.serverInfoErr
}
si, ok := m.serverInfos[name]
if !ok {
return nil, trace.NotFound("no server info named %q", name)
}
return si, nil
}

func (m *mockUpstream) Close() error {
func (m *mockServerInfoAccessPoint) UpdateLabels(_ context.Context, req proto.InventoryUpdateLabelsRequest) error {
m.updatedLabels[req.ServerID] = req.Labels
return nil
}

// TestReconcileLabels verifies that an SSH server's labels can be updated by
// upserting a corresponding ServerInfo to the auth server.
func TestReconcileLabels(t *testing.T) {
func (m *mockServerInfoAccessPoint) GetClock() clockwork.Clock {
return m.clock
}

func TestReconcileServerInfo(t *testing.T) {
t.Parallel()

const serverName = "test-server"
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

// Create auth server and fake inventory stream.
clock := clockwork.NewFakeClock()
pack, err := newTestPack(ctx, t.TempDir(), WithClock(clock))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, pack.a.Close())
require.NoError(t, pack.bk.Close())
awsServerInfo, err := types.NewServerInfo(types.Metadata{
Name: types.ServerInfoNameFromAWS("my-account", "my-instance"),
}, types.ServerInfoSpecV1{
NewLabels: map[string]string{"a": "1", "b": "2"},
})
upstream := &mockUpstream{}
t.Cleanup(func() {
require.NoError(t, upstream.Close())
require.NoError(t, err)
regularServerInfo, err := types.NewServerInfo(types.Metadata{
Name: types.ServerInfoNameFromNodeName(serverName),
}, types.ServerInfoSpecV1{
NewLabels: map[string]string{"b": "3", "c": "4"},
})
require.NoError(t, pack.a.RegisterInventoryControlStream(upstream, proto.UpstreamInventoryHello{
Version: teleport.Version,
ServerID: serverName,
Services: []types.SystemRole{types.RoleNode},
}))

// Create server.
require.NoError(t, err)
server, err := types.NewServer(serverName, types.KindNode, types.ServerSpecV2{
CloudMetadata: &types.CloudMetadata{
AWS: &types.AWSInfo{
@@ -92,29 +103,75 @@ func TestReconcileLabels(t *testing.T) {
},
})
require.NoError(t, err)
_, err = pack.a.UpsertNode(ctx, server)
require.NoError(t, err)

// Update the server's labels.
awsServerInfo, err := types.NewServerInfo(types.Metadata{
Name: types.ServerInfoNameFromAWS("my-account", "my-instance"),
}, types.ServerInfoSpecV1{
NewLabels: map[string]string{"a": "1", "b": "2"},
t.Run("ok", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

ap := newMockServerInfoAccessPoint()
ap.nodes = []types.Server{server}
ap.serverInfos = map[string]types.ServerInfo{
awsServerInfo.GetName(): awsServerInfo,
regularServerInfo.GetName(): regularServerInfo,
}

utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{
Name: "ReconcileServerInfos",
Task: func(ctx context.Context) error {
return trace.Wrap(ReconcileServerInfos(ctx, ap))
},
})

// Wait until the reconciler finishes processing a batch.
ap.clock.BlockUntil(1)
// Check that the right labels were updated.
require.Equal(t, map[string]string{
"aws/a": "1",
"aws/b": "2",
"dynamic/b": "3",
"dynamic/c": "4",
}, ap.updatedLabels[serverName])
})
require.NoError(t, err)
require.NoError(t, pack.a.UpsertServerInfo(ctx, awsServerInfo))

regularServerInfo, err := types.NewServerInfo(types.Metadata{
Name: types.ServerInfoNameFromNodeName(serverName),
}, types.ServerInfoSpecV1{
NewLabels: map[string]string{"b": "3", "c": "4"},
t.Run("restart on error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

ap := newMockServerInfoAccessPoint()
ap.nodes = []types.Server{server}
ap.nodesErr = trace.Errorf("an error")
ap.serverInfos = map[string]types.ServerInfo{
awsServerInfo.GetName(): awsServerInfo,
regularServerInfo.GetName(): regularServerInfo,
}

utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{
Name: "ReconcileServerInfos",
Task: func(ctx context.Context) error {
return trace.Wrap(ReconcileServerInfos(ctx, ap))
},
})

// Block until we hit the retryer.
ap.clock.BlockUntil(1)
// Return the error at a different place and advance to the next batch.
ap.nodesErr = nil
ap.serverInfoErr = trace.Errorf("an error")
ap.clock.Advance(defaults.MaxWatcherBackoff)
// Block until we hit the retryer again.
ap.clock.BlockUntil(1)
// Clear the error and allow a successful run.
ap.serverInfoErr = nil
ap.clock.Advance(defaults.MaxWatcherBackoff)
// Block until we hit the loop waiter (meaning the server infos were
// successfully processed).
ap.clock.BlockUntil(1)
// Check that the right labels were updated.
require.Equal(t, map[string]string{
"aws/a": "1",
"aws/b": "2",
"dynamic/b": "3",
"dynamic/c": "4",
}, ap.updatedLabels[serverName])
})
require.NoError(t, err)
require.NoError(t, pack.a.UpsertServerInfo(ctx, regularServerInfo))

go pack.a.ReconcileServerInfos(ctx)
// Wait until the reconciler finishes processing the serverinfo.
clock.BlockUntil(1)
// Check that labels were received downstream.
require.Equal(t, map[string]string{"aws/a": "1", "aws/b": "2", "dynamic/b": "3", "dynamic/c": "4"}, upstream.updatedLabels)
}
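
The test above drives the reconciler with clockwork's fake clock: BlockUntil waits until the code under test is parked on a timer, and Advance fires that timer deterministically. The following standalone sketch, which is not part of this change and uses illustrative names, shows the same technique in isolation.

package reconcileexample

import (
	"context"
	"testing"
	"time"

	"github.com/jonboulle/clockwork"
)

// waitThenSignal parks on the injected clock for one minute, then signals done.
func waitThenSignal(ctx context.Context, clock clockwork.Clock, done chan<- struct{}) {
	select {
	case <-clock.After(time.Minute):
		done <- struct{}{}
	case <-ctx.Done():
	}
}

// TestFakeClockDrivesTimer blocks until one goroutine is waiting on the fake
// clock, then advances past its timer so the test never sleeps in real time.
func TestFakeClockDrivesTimer(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	clock := clockwork.NewFakeClock()
	done := make(chan struct{}, 1)
	go waitThenSignal(ctx, clock, done)

	clock.BlockUntil(1)        // the goroutine is now parked on clock.After
	clock.Advance(time.Minute) // fire its timer immediately
	<-done
}
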
2 changes: 1 addition & 1 deletion lib/service/service.go
@@ -2421,7 +2421,7 @@ func (process *TeleportProcess) initAuthService() error {
})

process.RegisterFunc("auth.server_info", func() error {
return trace.Wrap(authServer.ReconcileServerInfos(process.GracefulExitContext()))
return trace.Wrap(auth.ReconcileServerInfos(process.GracefulExitContext(), authServer))
})
// execute this when process is asked to exit:
process.OnExit("auth.shutdown", func(payload any) {