From d9508672079f0a69c6db6208c14b8aca7059515b Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Wed, 3 Feb 2021 03:08:13 -0800 Subject: [PATCH] Fix cluster peer HTTP SRV discovery Signed-off-by: Brad Davidson --- embed/config.go | 18 +++++-- embed/config_test.go | 84 +++++++++++++++++++++++++++++- pkg/srv/srv_test.go | 118 ++++++++++++++++++++++++++++++++++--------- 3 files changed, 191 insertions(+), 29 deletions(-) diff --git a/embed/config.go b/embed/config.go index d76141339ef..9e0b87fc9da 100644 --- a/embed/config.go +++ b/embed/config.go @@ -39,6 +39,7 @@ import ( "go.etcd.io/etcd/pkg/types" bolt "go.etcd.io/bbolt" + "go.uber.org/multierr" "go.uber.org/zap" "go.uber.org/zap/zapcore" "golang.org/x/crypto/bcrypt" @@ -93,6 +94,9 @@ var ( defaultHostname string defaultHostStatus error + + // indirection for testing + getCluster = srv.GetCluster ) var ( @@ -726,6 +730,8 @@ func (cfg *Config) PeerURLsMapAndToken(which string) (urlsmap types.URLsMap, tok } else { plog.Errorf("couldn't resolve during SRV discovery (%v)", cerr) } + } + if len(clusterStrs) == 0 { return nil, "", cerr } for _, s := range clusterStrs { @@ -756,6 +762,10 @@ func (cfg *Config) PeerURLsMapAndToken(which string) (urlsmap types.URLsMap, tok } // GetDNSClusterNames uses DNS SRV records to get a list of initial nodes for cluster bootstrapping. +// This function will return a list of one or more nodes, as well as any errors encountered while +// performing service discovery. +// Note: Because this checks multiple sets of SRV records, discovery should only be considered to have +// failed if the returned node list is empty. func (cfg *Config) GetDNSClusterNames() ([]string, error) { var ( clusterStrs []string @@ -770,7 +780,7 @@ func (cfg *Config) GetDNSClusterNames() ([]string, error) { // Use both etcd-server-ssl and etcd-server for discovery. // Combine the results if both are available. - clusterStrs, cerr = srv.GetCluster("https", "etcd-server-ssl"+serviceNameSuffix, cfg.Name, cfg.DNSCluster, cfg.AdvertisePeerUrls) + clusterStrs, cerr = getCluster("https", "etcd-server-ssl"+serviceNameSuffix, cfg.Name, cfg.DNSCluster, cfg.AdvertisePeerUrls) if cerr != nil { clusterStrs = make([]string, 0) } @@ -787,8 +797,8 @@ func (cfg *Config) GetDNSClusterNames() ([]string, error) { ) } - defaultHTTPClusterStrs, httpCerr := srv.GetCluster("http", "etcd-server"+serviceNameSuffix, cfg.Name, cfg.DNSCluster, cfg.AdvertisePeerUrls) - if httpCerr != nil { + defaultHTTPClusterStrs, httpCerr := getCluster("http", "etcd-server"+serviceNameSuffix, cfg.Name, cfg.DNSCluster, cfg.AdvertisePeerUrls) + if httpCerr == nil { clusterStrs = append(clusterStrs, defaultHTTPClusterStrs...) } if lg != nil { @@ -804,7 +814,7 @@ func (cfg *Config) GetDNSClusterNames() ([]string, error) { ) } - return clusterStrs, cerr + return clusterStrs, multierr.Combine(cerr, httpCerr) } func (cfg Config) InitialClusterFromName(name string) (ret string) { diff --git a/embed/config_test.go b/embed/config_test.go index dad5a45a8f7..c9ff9f01c45 100644 --- a/embed/config_test.go +++ b/embed/config_test.go @@ -18,17 +18,25 @@ import ( "crypto/tls" "fmt" "io/ioutil" + "net" "net/url" "os" "testing" "time" "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/pkg/srv" "go.etcd.io/etcd/pkg/transport" + "go.etcd.io/etcd/pkg/types" "sigs.k8s.io/yaml" ) +func notFoundErr(service, domain string) error { + name := fmt.Sprintf("_%s._tcp.%s", service, domain) + return &net.DNSError{Err: "no such host", Name: name, Server: "10.0.0.53:53", IsTimeout: false, IsTemporary: false, IsNotFound: true} +} + func TestConfigFileOtherFields(t *testing.T) { ctls := securityConfig{TrustedCAFile: "cca", CertFile: "ccert", KeyFile: "ckey"} ptls := securityConfig{TrustedCAFile: "pca", CertFile: "pcert", KeyFile: "pkey"} @@ -86,7 +94,7 @@ func TestUpdateDefaultClusterFromName(t *testing.T) { // in case of 'etcd --name=abc' exp := fmt.Sprintf("%s=%s://localhost:%s", cfg.Name, oldscheme, lpport) - cfg.UpdateDefaultClusterFromName(defaultInitialCluster) + _, _ = cfg.UpdateDefaultClusterFromName(defaultInitialCluster) if exp != cfg.InitialCluster { t.Fatalf("initial-cluster expected %q, got %q", exp, cfg.InitialCluster) } @@ -281,3 +289,77 @@ func TestTLSVersionMinMax(t *testing.T) { }) } } + +func TestPeerURLsMapAndTokenFromSRV(t *testing.T) { + defer func() { getCluster = srv.GetCluster }() + tests := []struct { + withSSL []string + withoutSSL []string + apurls []string + wurls string + werr bool + }{ + { + []string{}, + []string{}, + []string{"http://localhost:2380"}, + "", + true, + }, + { + []string{"1.example.com=https://1.example.com:2380", "0=https://2.example.com:2380", "1=https://3.example.com:2380"}, + []string{}, + []string{"https://1.example.com:2380"}, + "0=https://2.example.com:2380,1.example.com=https://1.example.com:2380,1=https://3.example.com:2380", + false, + }, + { + []string{"1.example.com=https://1.example.com:2380"}, + []string{"0=http://2.example.com:2380", "1=http://3.example.com:2380"}, + []string{"https://1.example.com:2380"}, + "0=http://2.example.com:2380,1.example.com=https://1.example.com:2380,1=http://3.example.com:2380", + false, + }, + { + []string{}, + []string{"1.example.com=http://1.example.com:2380", "0=http://2.example.com:2380", "1=http://3.example.com:2380"}, + []string{"http://1.example.com:2380"}, + "0=http://2.example.com:2380,1.example.com=http://1.example.com:2380,1=http://3.example.com:2380", + false, + }, + } + hasErr := func(err error) bool { + return err != nil + } + for i, tt := range tests { + getCluster = func(serviceScheme string, service string, name string, dns string, apurls types.URLs) ([]string, error) { + var urls []string + if serviceScheme == "https" && service == "etcd-server-ssl" { + urls = tt.withSSL + } else if serviceScheme == "http" && service == "etcd-server" { + urls = tt.withoutSSL + } + if len(urls) > 0 { + return urls, nil + } + return urls, notFoundErr(service, dns) + } + cfg := NewConfig() + cfg.Name = "1.example.com" + cfg.InitialCluster = "" + cfg.InitialClusterToken = "" + cfg.DNSCluster = "example.com" + cfg.AdvertisePeerUrls = types.MustNewURLs(tt.apurls) + if err := cfg.Validate(); err != nil { + t.Errorf("#%d: failed to validate test Config: %v", i, err) + continue + } + urlsmap, _, err := cfg.PeerURLsMapAndToken("etcd") + if urlsmap.String() != tt.wurls { + t.Errorf("#%d: urlsmap = %s, want = %s", i, urlsmap.String(), tt.wurls) + } + if hasErr(err) != tt.werr { + t.Errorf("#%d: err = %v, want = %v", i, err, tt.werr) + } + } +} diff --git a/pkg/srv/srv_test.go b/pkg/srv/srv_test.go index 24a7cf22d5d..a962798ed10 100644 --- a/pkg/srv/srv_test.go +++ b/pkg/srv/srv_test.go @@ -16,6 +16,7 @@ package srv import ( "errors" + "fmt" "net" "reflect" "strings" @@ -24,12 +25,21 @@ import ( "go.etcd.io/etcd/pkg/testutil" ) +func notFoundErr(service, proto, domain string) error { + name := fmt.Sprintf("_%s._%s.%s", service, proto, domain) + return &net.DNSError{Err: "no such host", Name: name, Server: "10.0.0.53:53", IsTimeout: false, IsTemporary: false, IsNotFound: true} +} + func TestSRVGetCluster(t *testing.T) { defer func() { lookupSRV = net.LookupSRV resolveTCPAddr = net.ResolveTCPAddr }() + hasErr := func(err error) bool { + return err != nil + } + name := "dnsClusterTest" dns := map[string]string{ "1.example.com.:2480": "10.0.0.1:2480", @@ -42,57 +52,72 @@ func TestSRVGetCluster(t *testing.T) { {Target: "2.example.com.", Port: 2480}, {Target: "3.example.com.", Port: 2480}, } + srvNone := []*net.SRV{} tests := []struct { - scheme string - records []*net.SRV - urls []string - - expected string + service string + scheme string + withSSL []*net.SRV + withoutSSL []*net.SRV + urls []string + expected string + werr bool }{ { + "etcd-server-ssl", "https", - []*net.SRV{}, + srvNone, + srvNone, nil, - "", + true, }, { + "etcd-server-ssl", "https", srvAll, + srvNone, nil, - "0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480", + false, }, { + "etcd-server", "http", + srvNone, srvAll, nil, - "0=http://1.example.com:2480,1=http://2.example.com:2480,2=http://3.example.com:2480", + false, }, { + "etcd-server-ssl", "https", srvAll, + srvNone, []string{"https://10.0.0.1:2480"}, - "dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480", + false, }, // matching local member with resolved addr and return unresolved hostnames { + "etcd-server-ssl", "https", srvAll, + srvNone, []string{"https://10.0.0.1:2480"}, - "dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480", + false, }, // reject if apurls are TLS but SRV is only http { + "etcd-server", "http", + srvNone, srvAll, []string{"https://10.0.0.1:2480"}, - "0=http://2.example.com:2480,1=http://3.example.com:2480", + false, }, } @@ -109,12 +134,26 @@ func TestSRVGetCluster(t *testing.T) { for i, tt := range tests { lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) { - return "", tt.records, nil + if service == "etcd-server-ssl" { + if len(tt.withSSL) > 0 { + return "", tt.withSSL, nil + } + return "", nil, notFoundErr(service, proto, domain) + } + if service == "etcd-server" { + if len(tt.withoutSSL) > 0 { + return "", tt.withoutSSL, nil + } + return "", nil, notFoundErr(service, proto, domain) + } + return "", nil, errors.New("unknown service in mock") } + urls := testutil.MustNewURLs(t, tt.urls) - str, err := GetCluster(tt.scheme, "etcd-server", name, "example.com", urls) - if err != nil { - t.Fatalf("%d: err: %#v", i, err) + str, err := GetCluster(tt.scheme, tt.service, name, "example.com", urls) + + if hasErr(err) != tt.werr { + t.Fatalf("%d: err = %#v, want = %#v", i, err, tt.werr) } if strings.Join(str, ",") != tt.expected { t.Errorf("#%d: cluster = %s, want %s", i, str, tt.expected) @@ -125,15 +164,31 @@ func TestSRVGetCluster(t *testing.T) { func TestSRVDiscover(t *testing.T) { defer func() { lookupSRV = net.LookupSRV }() + hasErr := func(err error) bool { + return err != nil + } + tests := []struct { withSSL []*net.SRV withoutSSL []*net.SRV expected []string + werr bool }{ { []*net.SRV{}, []*net.SRV{}, []string{}, + true, + }, + { + []*net.SRV{}, + []*net.SRV{ + {Target: "10.0.0.1", Port: 2480}, + {Target: "10.0.0.2", Port: 2480}, + {Target: "10.0.0.3", Port: 2480}, + }, + []string{"http://10.0.0.1:2480", "http://10.0.0.2:2480", "http://10.0.0.3:2480"}, + false, }, { []*net.SRV{ @@ -143,6 +198,7 @@ func TestSRVDiscover(t *testing.T) { }, []*net.SRV{}, []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"}, + false, }, { []*net.SRV{ @@ -154,6 +210,7 @@ func TestSRVDiscover(t *testing.T) { {Target: "10.0.0.1", Port: 7001}, }, []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, + false, }, { []*net.SRV{ @@ -165,6 +222,7 @@ func TestSRVDiscover(t *testing.T) { {Target: "10.0.0.1", Port: 7001}, }, []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, + false, }, { []*net.SRV{ @@ -174,29 +232,41 @@ func TestSRVDiscover(t *testing.T) { }, []*net.SRV{}, []string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"}, + false, }, } for i, tt := range tests { lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) { if service == "etcd-client-ssl" { - return "", tt.withSSL, nil + if len(tt.withSSL) > 0 { + return "", tt.withSSL, nil + } + return "", nil, notFoundErr(service, proto, domain) } if service == "etcd-client" { - return "", tt.withoutSSL, nil + if len(tt.withoutSSL) > 0 { + return "", tt.withoutSSL, nil + } + return "", nil, notFoundErr(service, proto, domain) } return "", nil, errors.New("Unknown service in mock") } srvs, err := GetClient("etcd-client", "example.com", "") - if err != nil { - t.Fatalf("%d: err: %#v", i, err) - } - if !reflect.DeepEqual(srvs.Endpoints, tt.expected) { - t.Errorf("#%d: endpoints = %v, want %v", i, srvs.Endpoints, tt.expected) + if hasErr(err) != tt.werr { + t.Fatalf("%d: err = %#v, want = %#v", i, err, tt.werr) + } + if srvs == nil { + if len(tt.expected) > 0 { + t.Errorf("#%d: srvs = nil, want non-nil", i) + } + } else { + if !reflect.DeepEqual(srvs.Endpoints, tt.expected) { + t.Errorf("#%d: endpoints = %v, want = %v", i, srvs.Endpoints, tt.expected) + } } - } }