From a21eedf414bb906efab78176071e73a2731fa0cb Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Thu, 12 Oct 2023 10:09:38 -0700 Subject: [PATCH] internal/testutils: add a new test type that implements resolver.ClientConn (#6668) --- .../weightedtarget/weightedtarget_test.go | 32 +- clientconn.go | 2 +- .../gracefulswitch/gracefulswitch_test.go | 4 +- internal/balancergroup/balancergroup_test.go | 14 +- internal/resolver/dns/dns_resolver.go | 69 +- internal/resolver/dns/dns_resolver_test.go | 2013 +++++++---------- .../resolver/dns/fake_net_resolver_test.go | 123 + internal/resolver/dns/internal/internal.go | 70 + internal/testutils/balancer.go | 47 +- internal/testutils/resolver.go | 70 + .../balancer/clusterimpl/balancer_test.go | 14 +- .../clustermanager/clustermanager_test.go | 18 +- .../outlierdetection/balancer_test.go | 4 +- .../balancer/priority/balancer_test.go | 40 +- .../priority/ignore_resolve_now_test.go | 2 +- .../balancer/ringhash/ringhash_test.go | 4 +- .../balancer/wrrlocality/balancer_test.go | 2 +- 17 files changed, 1158 insertions(+), 1370 deletions(-) create mode 100644 internal/resolver/dns/fake_net_resolver_test.go create mode 100644 internal/resolver/dns/internal/internal.go create mode 100644 internal/testutils/resolver.go diff --git a/balancer/weightedtarget/weightedtarget_test.go b/balancer/weightedtarget/weightedtarget_test.go index 1fe16039e341..98f28a81891b 100644 --- a/balancer/weightedtarget/weightedtarget_test.go +++ b/balancer/weightedtarget/weightedtarget_test.go @@ -167,7 +167,7 @@ func init() { // glue code in weighted_target. It also tests an empty target config update, // which should trigger a transient failure state update. 
func (s) TestWeightedTarget(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -329,7 +329,7 @@ func (s) TestWeightedTarget(t *testing.T) { // have a weighted target balancer will one sub-balancer, and we add and remove // backends from the subBalancer. func (s) TestWeightedTarget_OneSubBalancer_AddRemoveBackend(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -427,7 +427,7 @@ func (s) TestWeightedTarget_OneSubBalancer_AddRemoveBackend(t *testing.T) { // TestWeightedTarget_TwoSubBalancers_OneBackend tests the case where we have a // weighted target balancer with two sub-balancers, each with one backend. func (s) TestWeightedTarget_TwoSubBalancers_OneBackend(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -493,7 +493,7 @@ func (s) TestWeightedTarget_TwoSubBalancers_OneBackend(t *testing.T) { // a weighted target balancer with two sub-balancers, each with more than one // backend. func (s) TestWeightedTarget_TwoSubBalancers_MoreBackends(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -637,7 +637,7 @@ func (s) TestWeightedTarget_TwoSubBalancers_MoreBackends(t *testing.T) { // case where we have a weighted target balancer with two sub-balancers of // differing weights. 
func (s) TestWeightedTarget_TwoSubBalancers_DifferentWeight_MoreBackends(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -718,7 +718,7 @@ func (s) TestWeightedTarget_TwoSubBalancers_DifferentWeight_MoreBackends(t *test // have a weighted target balancer with three sub-balancers and we remove one of // the subBalancers. func (s) TestWeightedTarget_ThreeSubBalancers_RemoveBalancer(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -879,7 +879,7 @@ func (s) TestWeightedTarget_ThreeSubBalancers_RemoveBalancer(t *testing.T) { // where we have a weighted target balancer with two sub-balancers, and we // change the weight of these subBalancers. func (s) TestWeightedTarget_TwoSubBalancers_ChangeWeight_MoreBackends(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -997,7 +997,7 @@ func (s) TestWeightedTarget_TwoSubBalancers_ChangeWeight_MoreBackends(t *testing // the picks won't fail with transient_failure, and should instead wait for the // other sub-balancer. func (s) TestWeightedTarget_InitOneSubBalancerTransientFailure(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -1059,7 +1059,7 @@ func (s) TestWeightedTarget_InitOneSubBalancerTransientFailure(t *testing.T) { // connecting, the overall state stays in transient_failure, and all picks // return transient failure error. 
func (s) TestBalancerGroup_SubBalancerTurnsConnectingFromTransientFailure(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -1141,7 +1141,7 @@ func (s) TestBalancerGroup_SubBalancerTurnsConnectingFromTransientFailure(t *tes // Verify that a SubConn is created with the expected address and hierarchy // path cleared. -func verifyAddressInNewSubConn(t *testing.T, cc *testutils.TestClientConn, addr resolver.Address) { +func verifyAddressInNewSubConn(t *testing.T, cc *testutils.BalancerClientConn, addr resolver.Address) { t.Helper() gotAddr := <-cc.NewSubConnAddrsCh @@ -1163,7 +1163,7 @@ type subConnWithAddr struct { // // Returned value is a map from subBalancer (identified by its config) to // subConns created by it. -func waitForNewSubConns(t *testing.T, cc *testutils.TestClientConn, num int) map[string][]subConnWithAddr { +func waitForNewSubConns(t *testing.T, cc *testutils.BalancerClientConn, num int) map[string][]subConnWithAddr { t.Helper() scs := make(map[string][]subConnWithAddr) @@ -1233,7 +1233,7 @@ func init() { // TestInitialIdle covers the case that if the child reports Idle, the overall // state will be Idle. func (s) TestInitialIdle(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -1274,7 +1274,7 @@ func (s) TestInitialIdle(t *testing.T) { // TestIgnoreSubBalancerStateTransitions covers the case that if the child reports a // transition from TF to Connecting, the overall state will still be TF. 
func (s) TestIgnoreSubBalancerStateTransitions(t *testing.T) { - cc := &tcc{TestClientConn: testutils.NewTestClientConn(t)} + cc := &tcc{BalancerClientConn: testutils.NewBalancerClientConn(t)} wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) defer wtb.Close() @@ -1314,17 +1314,17 @@ func (s) TestIgnoreSubBalancerStateTransitions(t *testing.T) { // tcc wraps a testutils.TestClientConn but stores all state transitions in a // slice. type tcc struct { - *testutils.TestClientConn + *testutils.BalancerClientConn states []balancer.State } func (t *tcc) UpdateState(bs balancer.State) { t.states = append(t.states, bs) - t.TestClientConn.UpdateState(bs) + t.BalancerClientConn.UpdateState(bs) } func (s) TestUpdateStatePauses(t *testing.T) { - cc := &tcc{TestClientConn: testutils.NewTestClientConn(t)} + cc := &tcc{BalancerClientConn: testutils.NewBalancerClientConn(t)} balFuncs := stub.BalancerFuncs{ UpdateClientConnState: func(bd *stub.BalancerData, s balancer.ClientConnState) error { diff --git a/clientconn.go b/clientconn.go index 429c389e4730..c7bf6849f07e 100644 --- a/clientconn.go +++ b/clientconn.go @@ -48,9 +48,9 @@ import ( "google.golang.org/grpc/status" _ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin. - _ "google.golang.org/grpc/internal/resolver/dns" // To register dns resolver. _ "google.golang.org/grpc/internal/resolver/passthrough" // To register passthrough resolver. _ "google.golang.org/grpc/internal/resolver/unix" // To register unix resolver. + _ "google.golang.org/grpc/resolver/dns" // To register dns resolver. 
) const ( diff --git a/internal/balancer/gracefulswitch/gracefulswitch_test.go b/internal/balancer/gracefulswitch/gracefulswitch_test.go index e92b09dbe4cf..80fe651ea29f 100644 --- a/internal/balancer/gracefulswitch/gracefulswitch_test.go +++ b/internal/balancer/gracefulswitch/gracefulswitch_test.go @@ -49,8 +49,8 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } -func setup(t *testing.T) (*testutils.TestClientConn, *Balancer) { - tcc := testutils.NewTestClientConn(t) +func setup(t *testing.T) (*testutils.BalancerClientConn, *Balancer) { + tcc := testutils.NewBalancerClientConn(t) return tcc, NewBalancer(tcc, balancer.BuildOptions{}) } diff --git a/internal/balancergroup/balancergroup_test.go b/internal/balancergroup/balancergroup_test.go index c57cf60ca84b..8daab7eeba72 100644 --- a/internal/balancergroup/balancergroup_test.go +++ b/internal/balancergroup/balancergroup_test.go @@ -73,7 +73,7 @@ func Test(t *testing.T) { // - b3, weight 1, backends [1,2] // Start the balancer group again and check for behavior. func (s) TestBalancerGroup_start_close(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) gator := weightedaggregator.New(cc, nil, testutils.NewTestWRR) gator.Start() bg := New(Options{ @@ -176,7 +176,7 @@ func (s) TestBalancerGroup_start_close_deadlock(t *testing.T) { stub.Register(balancerName, stub.BalancerFuncs{}) builder := balancer.Get(balancerName) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) gator := weightedaggregator.New(cc, nil, testutils.NewTestWRR) gator.Start() bg := New(Options{ @@ -203,8 +203,8 @@ func (s) TestBalancerGroup_start_close_deadlock(t *testing.T) { // Two rr balancers are added to bg, each with 2 ready subConns. A sub-balancer // is removed later, so the balancer group returned has one sub-balancer in its // own map, and one sub-balancer in cache. 
-func initBalancerGroupForCachingTest(t *testing.T, idleCacheTimeout time.Duration) (*weightedaggregator.Aggregator, *BalancerGroup, *testutils.TestClientConn, map[resolver.Address]*testutils.TestSubConn) { - cc := testutils.NewTestClientConn(t) +func initBalancerGroupForCachingTest(t *testing.T, idleCacheTimeout time.Duration) (*weightedaggregator.Aggregator, *BalancerGroup, *testutils.BalancerClientConn, map[resolver.Address]*testutils.TestSubConn) { + cc := testutils.NewBalancerClientConn(t) gator := weightedaggregator.New(cc, nil, testutils.NewTestWRR) gator.Start() bg := New(Options{ @@ -503,7 +503,7 @@ func (s) TestBalancerGroupBuildOptions(t *testing.T) { return nil }, }) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bg := New(Options{ CC: cc, BuildOpts: bOpts, @@ -531,7 +531,7 @@ func (s) TestBalancerExitIdleOne(t *testing.T) { exitIdleCh <- struct{}{} }, }) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bg := New(Options{ CC: cc, BuildOpts: balancer.BuildOptions{}, @@ -561,7 +561,7 @@ func (s) TestBalancerExitIdleOne(t *testing.T) { // for the second passed in address and also only picks that created SubConn. // The new aggregated picker should reflect this change for the child. 
func (s) TestBalancerGracefulSwitch(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) gator := weightedaggregator.New(cc, nil, testutils.NewTestWRR) gator.Start() bg := New(Options{ diff --git a/internal/resolver/dns/dns_resolver.go b/internal/resolver/dns/dns_resolver.go index 99e1e5b36c89..b66dcb213276 100644 --- a/internal/resolver/dns/dns_resolver.go +++ b/internal/resolver/dns/dns_resolver.go @@ -23,7 +23,6 @@ package dns import ( "context" "encoding/json" - "errors" "fmt" "net" "os" @@ -37,6 +36,7 @@ import ( "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/resolver/dns/internal" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -47,15 +47,11 @@ var EnableSRVLookups = false var logger = grpclog.Component("dns") -// Globals to stub out in tests. TODO: Perhaps these two can be combined into a -// single variable for testing the resolver? -var ( - newTimer = time.NewTimer - newTimerDNSResRate = time.NewTimer -) - func init() { resolver.Register(NewBuilder()) + internal.TimeAfterFunc = time.After + internal.NewNetResolver = newNetResolver + internal.AddressDialer = addressDialer } const ( @@ -70,23 +66,6 @@ const ( txtAttribute = "grpc_config=" ) -var ( - errMissingAddr = errors.New("dns resolver: missing address") - - // Addresses ending with a colon that is supposed to be the separator - // between host and port is not allowed. E.g. "::" is a valid address as - // it is an IPv6 address (host only) and "[::]:" is invalid as it ends with - // a colon as the host and port separator - errEndsWithColon = errors.New("dns resolver: missing port after port-separator colon") -) - -var ( - defaultResolver netResolver = net.DefaultResolver - // To prevent excessive re-resolution, we enforce a rate limit on DNS - // resolution requests. 
- minDNSResRate = 30 * time.Second -) - var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) { return func(ctx context.Context, network, _ string) (net.Conn, error) { var dialer net.Dialer @@ -94,7 +73,11 @@ var addressDialer = func(address string) func(context.Context, string, string) ( } } -var newNetResolver = func(authority string) (netResolver, error) { +var newNetResolver = func(authority string) (internal.NetResolver, error) { + if authority == "" { + return net.DefaultResolver, nil + } + host, port, err := parseTarget(authority, defaultDNSSvrPort) if err != nil { return nil, err @@ -104,7 +87,7 @@ var newNetResolver = func(authority string) (netResolver, error) { return &net.Resolver{ PreferGo: true, - Dial: addressDialer(authorityWithPort), + Dial: internal.AddressDialer(authorityWithPort), }, nil } @@ -142,13 +125,9 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts disableServiceConfig: opts.DisableServiceConfig, } - if target.URL.Host == "" { - d.resolver = defaultResolver - } else { - d.resolver, err = newNetResolver(target.URL.Host) - if err != nil { - return nil, err - } + d.resolver, err = internal.NewNetResolver(target.URL.Host) + if err != nil { + return nil, err } d.wg.Add(1) @@ -161,12 +140,6 @@ func (b *dnsBuilder) Scheme() string { return "dns" } -type netResolver interface { - LookupHost(ctx context.Context, host string) (addrs []string, err error) - LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error) - LookupTXT(ctx context.Context, name string) (txts []string, err error) -} - // deadResolver is a resolver that does nothing. 
type deadResolver struct{} @@ -178,7 +151,7 @@ func (deadResolver) Close() {} type dnsResolver struct { host string port string - resolver netResolver + resolver internal.NetResolver ctx context.Context cancel context.CancelFunc cc resolver.ClientConn @@ -223,29 +196,27 @@ func (d *dnsResolver) watcher() { err = d.cc.UpdateState(*state) } - var timer *time.Timer + var waitTime time.Duration if err == nil { // Success resolving, wait for the next ResolveNow. However, also wait 30 // seconds at the very least to prevent constantly re-resolving. backoffIndex = 1 - timer = newTimerDNSResRate(minDNSResRate) + waitTime = internal.MinResolutionRate select { case <-d.ctx.Done(): - timer.Stop() return case <-d.rn: } } else { // Poll on an error found in DNS Resolver or an error received from // ClientConn. - timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex)) + waitTime = backoff.DefaultExponential.Backoff(backoffIndex) backoffIndex++ } select { case <-d.ctx.Done(): - timer.Stop() return - case <-timer.C: + case <-internal.TimeAfterFunc(waitTime): } } } @@ -387,7 +358,7 @@ func formatIP(addr string) (addrIP string, ok bool) { // target: ":80" defaultPort: "443" returns host: "localhost", port: "80" func parseTarget(target, defaultPort string) (host, port string, err error) { if target == "" { - return "", "", errMissingAddr + return "", "", internal.ErrMissingAddr } if ip := net.ParseIP(target); ip != nil { // target is an IPv4 or IPv6(without brackets) address @@ -397,7 +368,7 @@ func parseTarget(target, defaultPort string) (host, port string, err error) { if port == "" { // If the port field is empty (target ends with colon), e.g. "[::1]:", // this is an error. 
- return "", "", errEndsWithColon + return "", "", internal.ErrEndsWithColon } // target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port if host == "" { diff --git a/internal/resolver/dns/dns_resolver_test.go b/internal/resolver/dns/dns_resolver_test.go index a66ffffd3ce1..1244edcb61cf 100644 --- a/internal/resolver/dns/dns_resolver_test.go +++ b/internal/resolver/dns/dns_resolver_test.go @@ -16,17 +16,15 @@ * */ -package dns +package dns_test import ( "context" "errors" "fmt" "net" - "os" - "reflect" "strings" - "sync" + "sync/atomic" "testing" "time" @@ -34,381 +32,168 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/balancer" grpclbstate "google.golang.org/grpc/balancer/grpclb/state" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" - "google.golang.org/grpc/internal/leakcheck" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/resolver/dns" + dnsinternal "google.golang.org/grpc/internal/resolver/dns/internal" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" -) -func TestMain(m *testing.M) { - // Set a non-zero duration only for tests which are actually testing that - // feature. 
- replaceDNSResRate(time.Duration(0)) // No nead to clean up since we os.Exit - overrideDefaultResolver(false) // No nead to clean up since we os.Exit - code := m.Run() - os.Exit(code) -} + _ "google.golang.org/grpc" // To initialize internal.ParseServiceConfig +) const ( txtBytesLimit = 255 defaultTestTimeout = 10 * time.Second defaultTestShortTimeout = 10 * time.Millisecond -) - -type testClientConn struct { - resolver.ClientConn // For unimplemented functions - target string - m1 sync.Mutex - state resolver.State - updateStateCalls int - errChan chan error - updateStateErr error -} - -func (t *testClientConn) UpdateState(s resolver.State) error { - t.m1.Lock() - defer t.m1.Unlock() - t.state = s - t.updateStateCalls++ - // This error determines whether DNS Resolver actually decides to exponentially backoff or not. - // This can be any error. - return t.updateStateErr -} -func (t *testClientConn) getState() (resolver.State, int) { - t.m1.Lock() - defer t.m1.Unlock() - return t.state, t.updateStateCalls -} - -func scFromState(s resolver.State) string { - if s.ServiceConfig != nil { - if s.ServiceConfig.Err != nil { - return "" - } - return s.ServiceConfig.Config.(unparsedServiceConfig).config - } - return "" -} + colonDefaultPort = ":443" +) -type unparsedServiceConfig struct { - serviceconfig.Config - config string +type s struct { + grpctest.Tester } -func (t *testClientConn) ParseServiceConfig(s string) *serviceconfig.ParseResult { - return &serviceconfig.ParseResult{Config: unparsedServiceConfig{config: s}} +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) } -func (t *testClientConn) ReportError(err error) { - t.errChan <- err +// Override the default net.Resolver with a test resolver. 
+func overrideNetResolver(t *testing.T, r *testNetResolver) { + origNetResolver := dnsinternal.NewNetResolver + dnsinternal.NewNetResolver = func(string) (dnsinternal.NetResolver, error) { return r, nil } + t.Cleanup(func() { dnsinternal.NewNetResolver = origNetResolver }) } -type testResolver struct { - // A write to this channel is made when this resolver receives a resolution - // request. Tests can rely on reading from this channel to be notified about - // resolution requests instead of sleeping for a predefined period of time. - lookupHostCh *testutils.Channel +// Override the DNS Min Res Rate used by the resolver. +func overrideResolutionRate(t *testing.T, d time.Duration) { + origMinResRate := dnsinternal.MinResolutionRate + dnsinternal.MinResolutionRate = d + t.Cleanup(func() { dnsinternal.MinResolutionRate = origMinResRate }) } -func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) { - if tr.lookupHostCh != nil { - tr.lookupHostCh.Send(nil) +// Override the timer used by the DNS resolver to fire after a duration of d. 
+func overrideTimeAfterFunc(t *testing.T, d time.Duration) { + origTimeAfter := dnsinternal.TimeAfterFunc + dnsinternal.TimeAfterFunc = func(time.Duration) <-chan time.Time { + return time.After(d) } - return hostLookup(host) + t.Cleanup(func() { dnsinternal.TimeAfterFunc = origTimeAfter }) } -func (*testResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { - return srvLookup(service, proto, name) +// Override the timer used by the DNS resolver as follows: +// - use the durChan to read the duration that the resolver wants to wait for +// - use the timerChan to unblock the wait on the timer +func overrideTimeAfterFuncWithChannel(t *testing.T) (durChan chan time.Duration, timeChan chan time.Time) { + origTimeAfter := dnsinternal.TimeAfterFunc + durChan = make(chan time.Duration, 1) + timeChan = make(chan time.Time) + dnsinternal.TimeAfterFunc = func(d time.Duration) <-chan time.Time { + select { + case durChan <- d: + default: + } + return timeChan + } + t.Cleanup(func() { dnsinternal.TimeAfterFunc = origTimeAfter }) + return durChan, timeChan } -func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, error) { - return txtLookup(host) +func enableSRVLookups(t *testing.T) { + origEnableSRVLookups := dns.EnableSRVLookups + dns.EnableSRVLookups = true + t.Cleanup(func() { dns.EnableSRVLookups = origEnableSRVLookups }) } -// overrideDefaultResolver overrides the defaultResolver used by the code with -// an instance of the testResolver. pushOnLookup controls whether the -// testResolver created here pushes lookupHost events on its channel. -func overrideDefaultResolver(pushOnLookup bool) func() { - oldResolver := defaultResolver +// Builds a DNS resolver for target and returns a couple of channels to read the +// state and error pushed by the resolver respectively. 
+func buildResolverWithTestClientConn(t *testing.T, target string) (resolver.Resolver, chan resolver.State, chan error) { + t.Helper() - var lookupHostCh *testutils.Channel - if pushOnLookup { - lookupHostCh = testutils.NewChannel() + b := resolver.Get("dns") + if b == nil { + t.Fatalf("Resolver for dns:/// scheme not registered") } - defaultResolver = &testResolver{lookupHostCh: lookupHostCh} - return func() { - defaultResolver = oldResolver + stateCh := make(chan resolver.State, 1) + updateStateF := func(s resolver.State) error { + select { + case stateCh <- s: + default: + } + return nil } -} -func replaceDNSResRate(d time.Duration) func() { - oldMinDNSResRate := minDNSResRate - minDNSResRate = d - - return func() { - minDNSResRate = oldMinDNSResRate + errCh := make(chan error, 1) + reportErrorF := func(err error) { + select { + case errCh <- err: + default: + } } -} - -var hostLookupTbl = struct { - sync.Mutex - tbl map[string][]string -}{ - tbl: map[string][]string{ - "foo.bar.com": {"1.2.3.4", "5.6.7.8"}, - "ipv4.single.fake": {"1.2.3.4"}, - "srv.ipv4.single.fake": {"2.4.6.8"}, - "srv.ipv4.multi.fake": {}, - "srv.ipv6.single.fake": {}, - "srv.ipv6.multi.fake": {}, - "ipv4.multi.fake": {"1.2.3.4", "5.6.7.8", "9.10.11.12"}, - "ipv6.single.fake": {"2607:f8b0:400a:801::1001"}, - "ipv6.multi.fake": {"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"}, - }, -} -func hostLookup(host string) ([]string, error) { - hostLookupTbl.Lock() - defer hostLookupTbl.Unlock() - if addrs, ok := hostLookupTbl.tbl[host]; ok { - return addrs, nil - } - return nil, &net.DNSError{ - Err: "hostLookup error", - Name: host, - Server: "fake", - IsTemporary: true, + tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF, ReportErrorF: reportErrorF} + r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", target))}, tcc, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("Failed to build DNS resolver for 
target %q: %v\n", target, err) } -} + t.Cleanup(func() { r.Close() }) -var srvLookupTbl = struct { - sync.Mutex - tbl map[string][]*net.SRV -}{ - tbl: map[string][]*net.SRV{ - "_grpclb._tcp.srv.ipv4.single.fake": {&net.SRV{Target: "ipv4.single.fake", Port: 1234}}, - "_grpclb._tcp.srv.ipv4.multi.fake": {&net.SRV{Target: "ipv4.multi.fake", Port: 1234}}, - "_grpclb._tcp.srv.ipv6.single.fake": {&net.SRV{Target: "ipv6.single.fake", Port: 1234}}, - "_grpclb._tcp.srv.ipv6.multi.fake": {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}}, - }, + return r, stateCh, errCh } -func srvLookup(service, proto, name string) (string, []*net.SRV, error) { - cname := "_" + service + "._" + proto + "." + name - srvLookupTbl.Lock() - defer srvLookupTbl.Unlock() - if srvs, cnt := srvLookupTbl.tbl[cname]; cnt { - return cname, srvs, nil - } - return "", nil, &net.DNSError{ - Err: "srvLookup error", - Name: cname, - Server: "fake", - IsTemporary: true, +// Waits for a state update from the DNS resolver and verifies the following: +// - wantAddrs matches the list of addresses in the update +// - wantBalancerAddrs matches the list of grpclb addresses in the update +// - wantSC matches the service config in the update +func verifyUpdateFromResolver(ctx context.Context, t *testing.T, stateCh chan resolver.State, wantAddrs, wantBalancerAddrs []resolver.Address, wantSC string) { + t.Helper() + + var state resolver.State + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for a state update from the resolver") + case state = <-stateCh: } -} -// scfs contains an array of service config file string in JSON format. -// Notes about the scfs contents and usage: -// scfs contains 4 service config file JSON strings for testing. Inside each -// service config file, there are multiple choices. scfs[0:3] each contains 5 -// choices, and first 3 choices are nonmatching choices based on canarying rule, -// while the last two are matched choices. 
scfs[3] only contains 3 choices, and -// all of them are nonmatching based on canarying rule. For each of scfs[0:3], -// the eventually returned service config, which is from the first of the two -// matched choices, is stored in the corresponding scs element (e.g. -// scfs[0]->scs[0]). scfs and scs elements are used in pair to test the dns -// resolver functionality, with scfs as the input and scs used for validation of -// the output. For scfs[3], it corresponds to empty service config, since there -// isn't a matched choice. -var scfs = []string{ - `[ - { - "clientLanguage": [ - "CPP", - "JAVA" - ], - "serviceConfig": { - "loadBalancingPolicy": "grpclb", - "methodConfig": [ - { - "name": [ - { - "service": "all" - } - ], - "timeout": "1s" - } - ] - } - }, - { - "percentage": 0, - "serviceConfig": { - "loadBalancingPolicy": "grpclb", - "methodConfig": [ - { - "name": [ - { - "service": "all" - } - ], - "timeout": "1s" - } - ] - } - }, - { - "clientHostName": [ - "localhost" - ], - "serviceConfig": { - "loadBalancingPolicy": "grpclb", - "methodConfig": [ - { - "name": [ - { - "service": "all" - } - ], - "timeout": "1s" - } - ] - } - }, - { - "clientLanguage": [ - "GO" - ], - "percentage": 100, - "serviceConfig": { - "methodConfig": [ - { - "name": [ - { - "method": "bar" - } - ], - "maxRequestMessageBytes": 1024, - "maxResponseMessageBytes": 1024 - } - ] - } - }, - { - "serviceConfig": { - "loadBalancingPolicy": "round_robin", - "methodConfig": [ - { - "name": [ - { - "service": "foo", - "method": "bar" - } - ], - "waitForReady": true - } - ] - } + if !cmp.Equal(state.Addresses, wantAddrs, cmpopts.EquateEmpty()) { + t.Fatalf("Got addresses: %+v, want: %+v", state.Addresses, wantAddrs) } -]`, - `[ - { - "clientLanguage": [ - "CPP", - "JAVA" - ], - "serviceConfig": { - "loadBalancingPolicy": "grpclb", - "methodConfig": [ - { - "name": [ - { - "service": "all" - } - ], - "timeout": "1s" - } - ] - } - }, - { - "percentage": 0, - "serviceConfig": { - 
"loadBalancingPolicy": "grpclb", - "methodConfig": [ - { - "name": [ - { - "service": "all" - } - ], - "timeout": "1s" - } - ] + if gs := grpclbstate.Get(state); gs == nil { + if len(wantBalancerAddrs) > 0 { + t.Fatalf("Got no grpclb addresses. Want %d", len(wantBalancerAddrs)) } - }, - { - "clientHostName": [ - "localhost" - ], - "serviceConfig": { - "loadBalancingPolicy": "grpclb", - "methodConfig": [ - { - "name": [ - { - "service": "all" - } - ], - "timeout": "1s" - } - ] + } else { + if !cmp.Equal(gs.BalancerAddresses, wantBalancerAddrs) { + t.Fatalf("Got grpclb addresses %+v, want %+v", gs.BalancerAddresses, wantBalancerAddrs) } - }, - { - "clientLanguage": [ - "GO" - ], - "percentage": 100, - "serviceConfig": { - "methodConfig": [ - { - "name": [ - { - "service": "foo", - "method": "bar" - } - ], - "waitForReady": true, - "timeout": "1s", - "maxRequestMessageBytes": 1024, - "maxResponseMessageBytes": 1024 - } - ] + } + if wantSC == "{}" { + if state.ServiceConfig != nil && state.ServiceConfig.Config != nil { + t.Fatalf("Got service config:\n%s \nWant service config: {}", cmp.Diff(nil, state.ServiceConfig.Config)) } - }, - { - "serviceConfig": { - "loadBalancingPolicy": "round_robin", - "methodConfig": [ - { - "name": [ - { - "service": "foo", - "method": "bar" - } - ], - "waitForReady": true - } - ] + + } else if wantSC != "" { + wantSCParsed := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(wantSC) + if !internal.EqualServiceConfigForTesting(state.ServiceConfig.Config, wantSCParsed.Config) { + t.Fatalf("Got service config:\n%s \nWant service config:\n%s", cmp.Diff(nil, state.ServiceConfig.Config), cmp.Diff(nil, wantSCParsed.Config)) } } -]`, - `[ +} + +// This is the service config used by the fake net.Resolver in its TXT record. 
+// - it contains an array of 5 entries +// - the first three will be dropped by the DNS resolver as part of its +// canarying rule matching functionality: +// - the client language does not match in the first entry +// - the percentage is set to 0 in the second entry +// - the client host name does not match in the third entry +// - the fourth and fifth entries will match the canarying rules, and therefore +// the fourth entry will be used as it will be the first matching entry. +const txtRecordGood = ` +[ { "clientLanguage": [ "CPP", @@ -506,8 +291,38 @@ var scfs = []string{ ] } } -]`, - `[ +]` + +// This is the matched portion of the above TXT record entry. +const scJSON = ` +{ + "loadBalancingPolicy": "round_robin", + "methodConfig": [ + { + "name": [ + { + "service": "foo" + } + ], + "waitForReady": true, + "timeout": "1s" + }, + { + "name": [ + { + "service": "bar" + } + ], + "waitForReady": false + } + ] +}` + +// This service config contains three entries, but none of the match the DNS +// resolver's canarying rules and hence the resulting service config pushed by +// the DNS resolver will be an empty one. +const txtRecordNonMatching = ` +[ { "clientLanguage": [ "CPP", @@ -561,904 +376,706 @@ var scfs = []string{ ] } } -]`, -} - -// scs contains an array of service config string in JSON format. 
-var scs = []string{ - `{ - "methodConfig": [ - { - "name": [ - { - "method": "bar" - } - ], - "maxRequestMessageBytes": 1024, - "maxResponseMessageBytes": 1024 - } - ] - }`, - `{ - "methodConfig": [ - { - "name": [ - { - "service": "foo", - "method": "bar" - } - ], - "waitForReady": true, - "timeout": "1s", - "maxRequestMessageBytes": 1024, - "maxResponseMessageBytes": 1024 - } - ] - }`, - `{ - "loadBalancingPolicy": "round_robin", - "methodConfig": [ - { - "name": [ - { - "service": "foo" - } - ], - "waitForReady": true, - "timeout": "1s" - }, - { - "name": [ - { - "service": "bar" - } - ], - "waitForReady": false - } - ] - }`, -} +]` -// scLookupTbl is a map, which contains targets that have service config to -// their configs. Targets not in this set should not have service config. -var scLookupTbl = map[string]string{ - "foo.bar.com": scs[0], - "srv.ipv4.single.fake": scs[1], - "srv.ipv4.multi.fake": scs[2], -} - -// generateSC returns a service config string in JSON format for the input name. -func generateSC(name string) string { - return scLookupTbl[name] -} - -// generateSCF generates a slice of strings (aggregately representing a single -// service config file) for the input config string, which mocks the result -// from a real DNS TXT record lookup. -func generateSCF(cfg string) []string { - b := append([]byte(txtAttribute), []byte(cfg)...) - - // Split b into multiple strings, each with a max of 255 bytes, which is - // the DNS TXT record limit. 
- var r []string - for i := 0; i < len(b); i += txtBytesLimit { - if i+txtBytesLimit > len(b) { - r = append(r, string(b[i:])) - } else { - r = append(r, string(b[i:i+txtBytesLimit])) - } - } - return r -} - -var txtLookupTbl = struct { - sync.Mutex - tbl map[string][]string -}{ - tbl: map[string][]string{ - txtPrefix + "foo.bar.com": generateSCF(scfs[0]), - txtPrefix + "srv.ipv4.single.fake": generateSCF(scfs[1]), - txtPrefix + "srv.ipv4.multi.fake": generateSCF(scfs[2]), - txtPrefix + "srv.ipv6.single.fake": generateSCF(scfs[3]), - txtPrefix + "srv.ipv6.multi.fake": generateSCF(scfs[3]), - }, -} - -func txtLookup(host string) ([]string, error) { - txtLookupTbl.Lock() - defer txtLookupTbl.Unlock() - if scs, cnt := txtLookupTbl.tbl[host]; cnt { - return scs, nil - } - return nil, &net.DNSError{ - Err: "txtLookup error", - Name: host, - Server: "fake", - IsTemporary: true, - } -} - -func TestResolve(t *testing.T) { - testDNSResolver(t) - testDNSResolverWithSRV(t) - testDNSResolveNow(t) - testIPResolver(t) -} - -func testDNSResolver(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(_ time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. - return time.NewTimer(time.Hour) - } +// Tests the scenario where a name resolves to a list of addresses, possibly +// some grpclb addresses as well, and a service config. The test verifies that +// the expected update is pushed to the channel. 
+func (s) TestDNSResolver_Basic(t *testing.T) { tests := []struct { - target string - addrWant []resolver.Address - scWant string + name string + target string + hostLookupTable map[string][]string + srvLookupTable map[string][]*net.SRV + txtLookupTable map[string][]string + wantAddrs []resolver.Address + wantBalancerAddrs []resolver.Address + wantSC string }{ { - "foo.bar.com", - []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, - generateSC("foo.bar.com"), + name: "default_port", + target: "foo.bar.com", + hostLookupTable: map[string][]string{ + "foo.bar.com": {"1.2.3.4", "5.6.7.8"}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, + wantAddrs: []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, + wantBalancerAddrs: nil, + wantSC: scJSON, }, { - "foo.bar.com:1234", - []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, - generateSC("foo.bar.com"), + name: "specified_port", + target: "foo.bar.com:1234", + hostLookupTable: map[string][]string{ + "foo.bar.com": {"1.2.3.4", "5.6.7.8"}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, + wantAddrs: []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, + wantBalancerAddrs: nil, + wantSC: scJSON, }, { - "srv.ipv4.single.fake", - []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, - generateSC("srv.ipv4.single.fake"), + name: "ipv4_with_SRV_and_single_grpclb_address", + target: "srv.ipv4.single.fake", + hostLookupTable: map[string][]string{ + "srv.ipv4.single.fake": {"2.4.6.8"}, + "ipv4.single.fake": {"1.2.3.4"}, + }, + srvLookupTable: map[string][]*net.SRV{ + "_grpclb._tcp.srv.ipv4.single.fake": {&net.SRV{Target: "ipv4.single.fake", Port: 1234}}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.srv.ipv4.single.fake": 
txtRecordServiceConfig(txtRecordGood), + }, + wantAddrs: []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, + wantBalancerAddrs: []resolver.Address{{Addr: "1.2.3.4:1234", ServerName: "ipv4.single.fake"}}, + wantSC: scJSON, }, { - "srv.ipv4.multi.fake", - nil, - generateSC("srv.ipv4.multi.fake"), + name: "ipv4_with_SRV_and_multiple_grpclb_address", + target: "srv.ipv4.multi.fake", + hostLookupTable: map[string][]string{ + "ipv4.multi.fake": {"1.2.3.4", "5.6.7.8", "9.10.11.12"}, + }, + srvLookupTable: map[string][]*net.SRV{ + "_grpclb._tcp.srv.ipv4.multi.fake": {&net.SRV{Target: "ipv4.multi.fake", Port: 1234}}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.srv.ipv4.multi.fake": txtRecordServiceConfig(txtRecordGood), + }, + wantAddrs: nil, + wantBalancerAddrs: []resolver.Address{ + {Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"}, + {Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"}, + {Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"}, + }, + wantSC: scJSON, }, { - "srv.ipv6.single.fake", - nil, - generateSC("srv.ipv6.single.fake"), + name: "ipv6_with_SRV_and_single_grpclb_address", + target: "srv.ipv6.single.fake", + hostLookupTable: map[string][]string{ + "srv.ipv6.single.fake": nil, + "ipv6.single.fake": {"2607:f8b0:400a:801::1001"}, + }, + srvLookupTable: map[string][]*net.SRV{ + "_grpclb._tcp.srv.ipv6.single.fake": {&net.SRV{Target: "ipv6.single.fake", Port: 1234}}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.srv.ipv6.single.fake": txtRecordServiceConfig(txtRecordNonMatching), + }, + wantAddrs: nil, + wantBalancerAddrs: []resolver.Address{{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.single.fake"}}, + wantSC: "{}", }, { - "srv.ipv6.multi.fake", - nil, - generateSC("srv.ipv6.multi.fake"), + name: "ipv6_with_SRV_and_multiple_grpclb_address", + target: "srv.ipv6.multi.fake", + hostLookupTable: map[string][]string{ + "srv.ipv6.multi.fake": nil, + "ipv6.multi.fake": {"2607:f8b0:400a:801::1001", 
"2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"}, + }, + srvLookupTable: map[string][]*net.SRV{ + "_grpclb._tcp.srv.ipv6.multi.fake": {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.srv.ipv6.multi.fake": txtRecordServiceConfig(txtRecordNonMatching), + }, + wantAddrs: nil, + wantBalancerAddrs: []resolver.Address{ + {Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.multi.fake"}, + {Addr: "[2607:f8b0:400a:801::1002]:1234", ServerName: "ipv6.multi.fake"}, + {Addr: "[2607:f8b0:400a:801::1003]:1234", ServerName: "ipv6.multi.fake"}, + }, + wantSC: "{}", }, } - for _, a := range tests { - b := NewBuilder() - cc := &testClientConn{target: a.target} - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("%v\n", err) - } - var state resolver.State - var cnt int - for i := 0; i < 2000; i++ { - state, cnt = cc.getState() - if cnt > 0 { - break - } - time.Sleep(time.Millisecond) - } - if cnt == 0 { - t.Fatalf("UpdateState not called after 2s; aborting") - } - if !cmp.Equal(a.addrWant, state.Addresses, cmpopts.EquateEmpty()) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) - } - sc := scFromState(state) - if a.scWant != sc { - t.Errorf("Resolved service config of target: %q = %+v, want %+v", a.target, sc, a.scWant) - } - r.Close() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + overrideTimeAfterFunc(t, 2*defaultTestTimeout) + overrideNetResolver(t, &testNetResolver{ + hostLookupTable: test.hostLookupTable, + srvLookupTable: test.srvLookupTable, + txtLookupTable: test.txtLookupTable, + }) + enableSRVLookups(t) + _, stateCh, _ := buildResolverWithTestClientConn(t, test.target) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + verifyUpdateFromResolver(ctx, t, stateCh, 
test.wantAddrs, test.wantBalancerAddrs, test.wantSC) + }) } } -// DNS Resolver immediately starts polling on an error from grpc. This should continue until the ClientConn doesn't -// send back an error from updating the DNS Resolver's state. -func TestDNSResolverExponentialBackoff(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - timerChan := testutils.NewChannel() - newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, allows this test to call timer immediately. - t := time.NewTimer(time.Hour) - timerChan.Send(t) - return t - } +// Tests the case where the channel returns an error for the update pushed by +// the DNS resolver. Verifies that the DNS resolver backs off before trying to +// resolve. Once the channel returns a nil error, the test verifies that the DNS +// resolver does not backoff anymore. +func (s) TestDNSResolver_ExponentialBackoff(t *testing.T) { tests := []struct { - name string - target string - addrWant []resolver.Address - scWant string + name string + target string + hostLookupTable map[string][]string + txtLookupTable map[string][]string + wantAddrs []resolver.Address + wantSC string }{ { - "happy case default port", - "foo.bar.com", - []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, - generateSC("foo.bar.com"), + name: "happy case default port", + target: "foo.bar.com", + hostLookupTable: map[string][]string{"foo.bar.com": {"1.2.3.4", "5.6.7.8"}}, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, + wantAddrs: []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, + wantSC: scJSON, }, { - "happy case specified port", - "foo.bar.com:1234", - []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, - generateSC("foo.bar.com"), + name: "happy case specified port", + target: 
"foo.bar.com:1234", + hostLookupTable: map[string][]string{"foo.bar.com": {"1.2.3.4", "5.6.7.8"}}, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, + wantAddrs: []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, + wantSC: scJSON, }, { - "happy case another default port", - "srv.ipv4.single.fake", - []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, - generateSC("srv.ipv4.single.fake"), + name: "happy case another default port", + target: "srv.ipv4.single.fake", + hostLookupTable: map[string][]string{ + "srv.ipv4.single.fake": {"2.4.6.8"}, + "ipv4.single.fake": {"1.2.3.4"}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.srv.ipv4.single.fake": txtRecordServiceConfig(txtRecordGood), + }, + wantAddrs: []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, + wantSC: scJSON, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - b := NewBuilder() - cc := &testClientConn{target: test.target} - // Cause ClientConn to return an error. 
- cc.updateStateErr = balancer.ErrBadResolverState - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", test.target))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("Error building resolver for target %v: %v", test.target, err) - } - var state resolver.State - var cnt int - for i := 0; i < 2000; i++ { - state, cnt = cc.getState() - if cnt > 0 { - break - } - time.Sleep(time.Millisecond) - } - if cnt == 0 { - t.Fatalf("UpdateState not called after 2s; aborting") - } - if !reflect.DeepEqual(test.addrWant, state.Addresses) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", test.target, state.Addresses, test.addrWant) - } - sc := scFromState(state) - if test.scWant != sc { - t.Errorf("Resolved service config of target: %q = %+v, want %+v", test.target, sc, test.scWant) - } - ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer ctxCancel() - // Cause timer to go off 10 times, and see if it calls updateState() correctly. - for i := 0; i < 10; i++ { - timer, err := timerChan.Receive(ctx) - if err != nil { - t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + durChan, timeChan := overrideTimeAfterFuncWithChannel(t) + overrideNetResolver(t, &testNetResolver{ + hostLookupTable: test.hostLookupTable, + txtLookupTable: test.txtLookupTable, + }) + + // Set the test clientconn to return error back to the resolver when + // it pushes an update on the channel. + var returnNilErr atomic.Bool + updateStateF := func(s resolver.State) error { + if returnNilErr.Load() { + return nil } - timerPointer := timer.(*time.Timer) - timerPointer.Reset(0) + return balancer.ErrBadResolverState } - // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call - // ClientConn update state. 
- deadline := time.Now().Add(defaultTestTimeout) - for { - cc.m1.Lock() - got := cc.updateStateCalls - cc.m1.Unlock() - if got == 11 { - break - } - - if time.Now().After(deadline) { - t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got) - } + tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF} - time.Sleep(time.Millisecond) + b := resolver.Get("dns") + if b == nil { + t.Fatalf("Resolver for dns:/// scheme not registered") } - - // Update resolver.ClientConn to not return an error anymore - this should stop it from backing off. - cc.updateStateErr = nil - timer, err := timerChan.Receive(ctx) + r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", test.target))}, tcc, resolver.BuildOptions{}) if err != nil { - t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + t.Fatalf("Failed to build DNS resolver for target %q: %v\n", test.target, err) } - timerPointer := timer.(*time.Timer) - timerPointer.Reset(0) - // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call - // ClientConn update state the final time. The DNS Resolver should then stop polling. - deadline = time.Now().Add(defaultTestTimeout) - for { - cc.m1.Lock() - got := cc.updateStateCalls - cc.m1.Unlock() - if got == 12 { - break + defer r.Close() + + // Expect the DNS resolver to backoff and attempt to re-resolve. 
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + const retries = 10 + var prevDur time.Duration + for i := 0; i < retries; i++ { + select { + case <-ctx.Done(): + t.Fatalf("(Iteration: %d): Timeout when waiting for DNS resolver to backoff", i) + case dur := <-durChan: + if dur <= prevDur { + t.Fatalf("(Iteration: %d): Unexpected decrease in amount of time to backoff", i) + } } - if time.Now().After(deadline) { - t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got) - } + // Unblock the DNS resolver's backoff by pushing the current time. + timeChan <- time.Now() + } - _, err := timerChan.ReceiveOrFail() - if err { - t.Fatalf("Should not poll again after Client Conn stops returning error.") - } + // Update resolver.ClientConn to not return an error anymore. + returnNilErr.Store(true) - time.Sleep(time.Millisecond) + // Unblock the DNS resolver's backoff, if ongoing, while we set the + // test clientConn to not return an error anymore. + select { + case timeChan <- time.Now(): + default: + } + + // Verify that the DNS resolver does not backoff anymore. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + select { + case <-durChan: + t.Fatal("Unexpected DNS resolver backoff") + case <-sCtx.Done(): } - r.Close() }) } } -func testDNSResolverWithSRV(t *testing.T) { - EnableSRVLookups = true - defer func() { - EnableSRVLookups = false - }() - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(_ time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. - return time.NewTimer(time.Hour) +// Tests the case where the DNS resolver is asked to re-resolve by invoking the +// ResolveNow method. 
+func (s) TestDNSResolver_ResolveNow(t *testing.T) { + const target = "foo.bar.com" + + overrideResolutionRate(t, 0) + overrideTimeAfterFunc(t, 0) + tr := &testNetResolver{ + hostLookupTable: map[string][]string{ + "foo.bar.com": {"1.2.3.4", "5.6.7.8"}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, } + overrideNetResolver(t, tr) + + r, stateCh, _ := buildResolverWithTestClientConn(t, target) + + // Verify that the first update pushed by the resolver matches expectations. + wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}} + wantSC := scJSON + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + verifyUpdateFromResolver(ctx, t, stateCh, wantAddrs, nil, wantSC) + + // Update state in the fake net.Resolver to return only one address and a + // new service config. + tr.UpdateHostLookupTable(map[string][]string{target: {"1.2.3.4"}}) + tr.UpdateTXTLookupTable(map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(`[{"serviceConfig":{"loadBalancingPolicy": "grpclb"}}]`), + }) + + // Ask the resolver to re-resolve and verify that the new update matches + // expectations. + r.ResolveNow(resolver.ResolveNowOptions{}) + wantAddrs = []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}} + wantSC = `{"loadBalancingPolicy": "grpclb"}` + verifyUpdateFromResolver(ctx, t, stateCh, wantAddrs, nil, wantSC) + + // Update state in the fake resolver to return no addresses and the same + // service config as before. + tr.UpdateHostLookupTable(map[string][]string{target: nil}) + + // Ask the resolver to re-resolve and verify that the new update matches + // expectations. 
+ r.ResolveNow(resolver.ResolveNowOptions{}) + verifyUpdateFromResolver(ctx, t, stateCh, nil, nil, wantSC) +} + +// Tests the case where the given name is an IP address and verifies that the +// update pushed by the DNS resolver meets expectations. +func (s) TestIPResolver(t *testing.T) { tests := []struct { - target string - addrWant []resolver.Address - grpclbAddrs []resolver.Address - scWant string + name string + target string + wantAddr []resolver.Address }{ { - "foo.bar.com", - []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, - nil, - generateSC("foo.bar.com"), + name: "localhost ipv4 default port", + target: "127.0.0.1", + wantAddr: []resolver.Address{{Addr: "127.0.0.1:443"}}, }, { - "foo.bar.com:1234", - []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, - nil, - generateSC("foo.bar.com"), + name: "localhost ipv4 non-default port", + target: "127.0.0.1:12345", + wantAddr: []resolver.Address{{Addr: "127.0.0.1:12345"}}, }, { - "srv.ipv4.single.fake", - []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, - []resolver.Address{{Addr: "1.2.3.4:1234", ServerName: "ipv4.single.fake"}}, - generateSC("srv.ipv4.single.fake"), + name: "localhost ipv6 default port no brackets", + target: "::1", + wantAddr: []resolver.Address{{Addr: "[::1]:443"}}, }, { - "srv.ipv4.multi.fake", - nil, - []resolver.Address{ - {Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"}, - {Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"}, - {Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"}, - }, - generateSC("srv.ipv4.multi.fake"), + name: "localhost ipv6 default port with brackets", + target: "[::1]", + wantAddr: []resolver.Address{{Addr: "[::1]:443"}}, }, { - "srv.ipv6.single.fake", - nil, - []resolver.Address{{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.single.fake"}}, - generateSC("srv.ipv6.single.fake"), + name: "localhost ipv6 non-default port", + target: "[::1]:12345", + wantAddr: 
[]resolver.Address{{Addr: "[::1]:12345"}}, }, { - "srv.ipv6.multi.fake", - nil, - []resolver.Address{ - {Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.multi.fake"}, - {Addr: "[2607:f8b0:400a:801::1002]:1234", ServerName: "ipv6.multi.fake"}, - {Addr: "[2607:f8b0:400a:801::1003]:1234", ServerName: "ipv6.multi.fake"}, - }, - generateSC("srv.ipv6.multi.fake"), + name: "ipv6 default port no brackets", + target: "2001:db8:85a3::8a2e:370:7334", + wantAddr: []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:443"}}, + }, + { + name: "ipv6 default port with brackets", + target: "[2001:db8:85a3::8a2e:370:7334]", + wantAddr: []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:443"}}, + }, + { + name: "ipv6 non-default port with brackets", + target: "[2001:db8:85a3::8a2e:370:7334]:12345", + wantAddr: []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:12345"}}, + }, + { + name: "abbreviated ipv6 address", + target: "[2001:db8::1]:http", + wantAddr: []resolver.Address{{Addr: "[2001:db8::1]:http"}}, }, + // TODO(yuxuanli): zone support? } - for _, a := range tests { - b := NewBuilder() - cc := &testClientConn{target: a.target} - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("%v\n", err) - } - defer r.Close() - var state resolver.State - var cnt int - for i := 0; i < 2000; i++ { - state, cnt = cc.getState() - if cnt > 0 { - break + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + overrideResolutionRate(t, 0) + overrideTimeAfterFunc(t, 2*defaultTestTimeout) + r, stateCh, _ := buildResolverWithTestClientConn(t, test.target) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + verifyUpdateFromResolver(ctx, t, stateCh, test.wantAddr, nil, "") + + // Attempt to re-resolve should not result in a state update. 
+ r.ResolveNow(resolver.ResolveNowOptions{}) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + select { + case <-sCtx.Done(): + case s := <-stateCh: + t.Fatalf("Unexpected state update from the resolver: %+v", s) } - time.Sleep(time.Millisecond) - } - if cnt == 0 { - t.Fatalf("UpdateState not called after 2s; aborting") - } - if !cmp.Equal(a.addrWant, state.Addresses, cmpopts.EquateEmpty()) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) - } - gs := grpclbstate.Get(state) - if (gs == nil && len(a.grpclbAddrs) > 0) || - (gs != nil && !reflect.DeepEqual(a.grpclbAddrs, gs.BalancerAddresses)) { - t.Errorf("Resolved state of target: %q = %+v (State=%+v), want state.Attributes.State=%+v", a.target, state, gs, a.grpclbAddrs) - } - sc := scFromState(state) - if a.scWant != sc { - t.Errorf("Resolved service config of target: %q = %+v, want %+v", a.target, sc, a.scWant) - } - } -} - -func mutateTbl(target string) func() { - hostLookupTbl.Lock() - oldHostTblEntry := hostLookupTbl.tbl[target] - hostLookupTbl.tbl[target] = hostLookupTbl.tbl[target][:len(oldHostTblEntry)-1] - hostLookupTbl.Unlock() - txtLookupTbl.Lock() - oldTxtTblEntry := txtLookupTbl.tbl[txtPrefix+target] - txtLookupTbl.tbl[txtPrefix+target] = []string{txtAttribute + `[{"serviceConfig":{"loadBalancingPolicy": "grpclb"}}]`} - txtLookupTbl.Unlock() - - return func() { - hostLookupTbl.Lock() - hostLookupTbl.tbl[target] = oldHostTblEntry - hostLookupTbl.Unlock() - txtLookupTbl.Lock() - if len(oldTxtTblEntry) == 0 { - delete(txtLookupTbl.tbl, txtPrefix+target) - } else { - txtLookupTbl.tbl[txtPrefix+target] = oldTxtTblEntry - } - txtLookupTbl.Unlock() + }) } } -func testDNSResolveNow(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(_ time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering 
exponential backoff. - return time.NewTimer(time.Hour) - } +// Tests the DNS resolver builder with different target names. +func (s) TestResolverBuild(t *testing.T) { tests := []struct { - target string - addrWant []resolver.Address - addrNext []resolver.Address - scWant string - scNext string + name string + target string + wantErr string }{ { - "foo.bar.com", - []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, - []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}}, - generateSC("foo.bar.com"), - `{"loadBalancingPolicy": "grpclb"}`, + name: "valid url", + target: "www.google.com", + }, + { + name: "host port", + target: "foo.bar:12345", + }, + { + name: "ipv4 address with default port", + target: "127.0.0.1", + }, + { + name: "ipv6 address without brackets and default port", + target: "::", + }, + { + name: "ipv4 address with non-default port", + target: "127.0.0.1:12345", + }, + { + name: "localhost ipv6 with brackets", + target: "[::1]:80", + }, + { + name: "ipv6 address with brackets", + target: "[2001:db8:a0b:12f0::1]:21", + }, + { + name: "empty host with port", + target: ":80", + }, + { + name: "ipv6 address with zone", + target: "[fe80::1%25lo0]:80", + }, + { + name: "url with port", + target: "golang.org:http", + }, + { + name: "ipv6 address with non integer port", + target: "[2001:db8::1]:http", + }, + { + name: "address ends with colon", + target: "[2001:db8::1]:", + wantErr: dnsinternal.ErrEndsWithColon.Error(), + }, + { + name: "address contains only a colon", + target: ":", + wantErr: dnsinternal.ErrEndsWithColon.Error(), + }, + { + name: "empty address", + target: "", + wantErr: dnsinternal.ErrMissingAddr.Error(), + }, + { + name: "invalid address", + target: "[2001:db8:a0b:12f0::1", + wantErr: "invalid target address", }, } - for _, a := range tests { - b := NewBuilder() - cc := &testClientConn{target: a.target} - r, err := b.Build(resolver.Target{URL: 
*testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("%v\n", err) - } - defer r.Close() - var state resolver.State - var cnt int - for i := 0; i < 2000; i++ { - state, cnt = cc.getState() - if cnt > 0 { - break - } - time.Sleep(time.Millisecond) - } - if cnt == 0 { - t.Fatalf("UpdateState not called after 2s; aborting. state=%v", state) - } - if !reflect.DeepEqual(a.addrWant, state.Addresses) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) - } - sc := scFromState(state) - if a.scWant != sc { - t.Errorf("Resolved service config of target: %q = %+v, want %+v", a.target, sc, a.scWant) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + overrideTimeAfterFunc(t, 2*defaultTestTimeout) - revertTbl := mutateTbl(a.target) - r.ResolveNow(resolver.ResolveNowOptions{}) - for i := 0; i < 2000; i++ { - state, cnt = cc.getState() - if cnt == 2 { - break + b := resolver.Get("dns") + if b == nil { + t.Fatalf("Resolver for dns:/// scheme not registered") } - time.Sleep(time.Millisecond) - } - if cnt != 2 { - t.Fatalf("UpdateState not called after 2s; aborting. state=%v", state) - } - sc = scFromState(state) - if !reflect.DeepEqual(a.addrNext, state.Addresses) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrNext) - } - if a.scNext != sc { - t.Errorf("Resolved service config of target: %q = %+v, want %+v", a.target, sc, a.scNext) - } - revertTbl() - } -} - -const colonDefaultPort = ":" + defaultPort -func testIPResolver(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(_ time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. 
- return time.NewTimer(time.Hour) - } - tests := []struct { - target string - want []resolver.Address - }{ - {"127.0.0.1", []resolver.Address{{Addr: "127.0.0.1" + colonDefaultPort}}}, - {"127.0.0.1:12345", []resolver.Address{{Addr: "127.0.0.1:12345"}}}, - {"::1", []resolver.Address{{Addr: "[::1]" + colonDefaultPort}}}, - {"[::1]:12345", []resolver.Address{{Addr: "[::1]:12345"}}}, - {"[::1]", []resolver.Address{{Addr: "[::1]:443"}}}, - {"2001:db8:85a3::8a2e:370:7334", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]" + colonDefaultPort}}}, - {"[2001:db8:85a3::8a2e:370:7334]", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]" + colonDefaultPort}}}, - {"[2001:db8:85a3::8a2e:370:7334]:12345", []resolver.Address{{Addr: "[2001:db8:85a3::8a2e:370:7334]:12345"}}}, - {"[2001:db8::1]:http", []resolver.Address{{Addr: "[2001:db8::1]:http"}}}, - // TODO(yuxuanli): zone support? - } - - for _, v := range tests { - b := NewBuilder() - cc := &testClientConn{target: v.target} - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", v.target))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("%v\n", err) - } - var state resolver.State - var cnt int - for { - state, cnt = cc.getState() - if cnt > 0 { - break + tcc := &testutils.ResolverClientConn{Logger: t} + r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", test.target))}, tcc, resolver.BuildOptions{}) + if err != nil { + if test.wantErr == "" { + t.Fatalf("DNS resolver build for target %q failed with error: %v", test.target, err) + } + if !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("DNS resolver build for target %q failed with error: %v, wantErr: %s", test.target, err, test.wantErr) + } + return } - time.Sleep(time.Millisecond) - } - if !reflect.DeepEqual(v.want, state.Addresses) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", v.target, state.Addresses, v.want) - } - 
r.ResolveNow(resolver.ResolveNowOptions{}) - for i := 0; i < 50; i++ { - state, cnt = cc.getState() - if cnt > 1 { - t.Fatalf("Unexpected second call by resolver to UpdateState. state: %v", state) + if err == nil && test.wantErr != "" { + t.Fatalf("DNS resolver build for target %q succeeded when expected to fail with error: %s", test.target, test.wantErr) } - time.Sleep(time.Millisecond) - } - r.Close() - } -} - -func TestResolveFunc(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. - return time.NewTimer(time.Hour) - } - tests := []struct { - addr string - want error - }{ - // TODO(yuxuanli): More false cases? - {"www.google.com", nil}, - {"foo.bar:12345", nil}, - {"127.0.0.1", nil}, - {"::", nil}, - {"127.0.0.1:12345", nil}, - {"[::1]:80", nil}, - {"[2001:db8:a0b:12f0::1]:21", nil}, - {":80", nil}, - {"127.0.0...1:12345", nil}, - {"[fe80::1%25lo0]:80", nil}, - {"golang.org:http", nil}, - {"[2001:db8::1]:http", nil}, - {"[2001:db8::1]:", errEndsWithColon}, - {":", errEndsWithColon}, - {"", errMissingAddr}, - {"[2001:db8:a0b:12f0::1", fmt.Errorf("invalid target address [2001:db8:a0b:12f0::1, error info: address [2001:db8:a0b:12f0::1:443: missing ']' in address")}, - } - - b := NewBuilder() - for _, v := range tests { - cc := &testClientConn{target: v.addr, errChan: make(chan error, 1)} - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", v.addr))}, cc, resolver.BuildOptions{}) - if err == nil { r.Close() - } - if !reflect.DeepEqual(err, v.want) { - t.Errorf("Build(%q, cc, _) = %v, want %v", v.addr, err, v.want) - } + }) } } -func TestDisableServiceConfig(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(d time.Duration) 
*time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. - return time.NewTimer(time.Hour) - } +// Tests scenarios where fetching of service config is enabled or disabled, and +// verifies that the expected update is pushed by the DNS resolver. +func (s) TestDisableServiceConfig(t *testing.T) { tests := []struct { + name string target string - scWant string + hostLookupTable map[string][]string + txtLookupTable map[string][]string disableServiceConfig bool + wantAddrs []resolver.Address + wantSC string }{ { - "foo.bar.com", - generateSC("foo.bar.com"), - false, + name: "false", + target: "foo.bar.com", + hostLookupTable: map[string][]string{"foo.bar.com": {"1.2.3.4", "5.6.7.8"}}, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, + disableServiceConfig: false, + wantAddrs: []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, + wantSC: scJSON, }, { - "foo.bar.com", - "", - true, + name: "true", + target: "foo.bar.com", + hostLookupTable: map[string][]string{"foo.bar.com": {"1.2.3.4", "5.6.7.8"}}, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, + disableServiceConfig: true, + wantAddrs: []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, + wantSC: "{}", }, } - for _, a := range tests { - b := NewBuilder() - cc := &testClientConn{target: a.target} - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{DisableServiceConfig: a.disableServiceConfig}) - if err != nil { - t.Fatalf("%v\n", err) - } - defer r.Close() - var cnt int - var state resolver.State - for i := 0; i < 2000; i++ { - state, cnt = cc.getState() - if cnt > 0 { - break + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + overrideTimeAfterFunc(t, 
2*defaultTestTimeout) + overrideNetResolver(t, &testNetResolver{ + hostLookupTable: test.hostLookupTable, + txtLookupTable: test.txtLookupTable, + }) + + b := resolver.Get("dns") + if b == nil { + t.Fatalf("Resolver for dns:/// scheme not registered") } - time.Sleep(time.Millisecond) - } - if cnt == 0 { - t.Fatalf("UpdateState not called after 2s; aborting") - } - sc := scFromState(state) - if a.scWant != sc { - t.Errorf("Resolved service config of target: %q = %+v, want %+v", a.target, sc, a.scWant) - } + + stateCh := make(chan resolver.State, 1) + updateStateF := func(s resolver.State) error { + stateCh <- s + return nil + } + tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF} + r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", test.target))}, tcc, resolver.BuildOptions{DisableServiceConfig: test.disableServiceConfig}) + if err != nil { + t.Fatalf("Failed to build DNS resolver for target %q: %v\n", test.target, err) + } + defer r.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + verifyUpdateFromResolver(ctx, t, stateCh, test.wantAddrs, nil, test.wantSC) + }) } } -func TestTXTError(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. - return time.NewTimer(time.Hour) - } - defer func(v bool) { envconfig.TXTErrIgnore = v }(envconfig.TXTErrIgnore) +// Tests the case where a TXT lookup is expected to return an error. Verifies +// that errors are ignored when the corresponding env var is set. +func (s) TestTXTError(t *testing.T) { for _, ignore := range []bool{false, true} { - envconfig.TXTErrIgnore = ignore - b := NewBuilder() - cc := &testClientConn{target: "ipv4.single.fake"} // has A records but not TXT records. 
- r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", "ipv4.single.fake"))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("%v\n", err) - } - defer r.Close() - var cnt int - var state resolver.State - for i := 0; i < 2000; i++ { - state, cnt = cc.getState() - if cnt > 0 { - break + t.Run(fmt.Sprintf("%v", ignore), func(t *testing.T) { + overrideTimeAfterFunc(t, 2*defaultTestTimeout) + overrideNetResolver(t, &testNetResolver{hostLookupTable: map[string][]string{"ipv4.single.fake": {"1.2.3.4"}}}) + + origTXTIgnore := envconfig.TXTErrIgnore + envconfig.TXTErrIgnore = ignore + defer func() { envconfig.TXTErrIgnore = origTXTIgnore }() + + // There is no entry for "ipv4.single.fake" in the txtLookupTbl + // maintained by the fake net.Resolver. So, a TXT lookup for this + // name will return an error. + _, stateCh, _ := buildResolverWithTestClientConn(t, "ipv4.single.fake") + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + var state resolver.State + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for a state update from the resolver") + case state = <-stateCh: } - time.Sleep(time.Millisecond) - } - if cnt == 0 { - t.Fatalf("UpdateState not called after 2s; aborting") - } - if !ignore && (state.ServiceConfig == nil || state.ServiceConfig.Err == nil) { - t.Errorf("state.ServiceConfig = %v; want non-nil error", state.ServiceConfig) - } else if ignore && state.ServiceConfig != nil { - t.Errorf("state.ServiceConfig = %v; want nil", state.ServiceConfig) - } - } -} -func TestDNSResolverRetry(t *testing.T) { - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. 
- return time.NewTimer(time.Hour) - } - b := NewBuilder() - target := "ipv4.single.fake" - cc := &testClientConn{target: target} - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("%v\n", err) - } - defer r.Close() - var state resolver.State - for i := 0; i < 2000; i++ { - state, _ = cc.getState() - if len(state.Addresses) == 1 { - break - } - time.Sleep(time.Millisecond) - } - if len(state.Addresses) != 1 { - t.Fatalf("UpdateState not called with 1 address after 2s; aborting. state=%v", state) - } - want := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}} - if !reflect.DeepEqual(want, state.Addresses) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want) - } - // mutate the host lookup table so the target has 0 address returned. - revertTbl := mutateTbl(target) - // trigger a resolve that will get empty address list - r.ResolveNow(resolver.ResolveNowOptions{}) - for i := 0; i < 2000; i++ { - state, _ = cc.getState() - if len(state.Addresses) == 0 { - break - } - time.Sleep(time.Millisecond) - } - if len(state.Addresses) != 0 { - t.Fatalf("UpdateState not called with 0 address after 2s; aborting. state=%v", state) - } - revertTbl() - // wait for the retry to happen in two seconds. 
- r.ResolveNow(resolver.ResolveNowOptions{}) - for i := 0; i < 2000; i++ { - state, _ = cc.getState() - if len(state.Addresses) == 1 { - break - } - time.Sleep(time.Millisecond) - } - if !reflect.DeepEqual(want, state.Addresses) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want) + if ignore { + if state.ServiceConfig != nil { + t.Fatalf("Received non-nil service config: %+v; want nil", state.ServiceConfig) + } + } else { + if state.ServiceConfig == nil || state.ServiceConfig.Err == nil { + t.Fatalf("Received service config %+v; want non-nil error", state.ServiceConfig) + } + } + }) } } -func TestCustomAuthority(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. - return time.NewTimer(time.Hour) - } - +// Tests different cases for a user's dial target that specifies a non-empty +// authority (or Host field of the URL). 
+func (s) TestCustomAuthority(t *testing.T) { tests := []struct { + name string authority string - authorityWant string - expectError bool + wantAuthority string + wantBuildErr bool }{ { - "4.3.2.1:" + defaultDNSSvrPort, - "4.3.2.1:" + defaultDNSSvrPort, - false, + name: "authority with default DNS port", + authority: "4.3.2.1:53", + wantAuthority: "4.3.2.1:53", }, { - "4.3.2.1:123", - "4.3.2.1:123", - false, + name: "authority with non-default DNS port", + authority: "4.3.2.1:123", + wantAuthority: "4.3.2.1:123", }, { - "4.3.2.1", - "4.3.2.1:" + defaultDNSSvrPort, - false, + name: "authority with no port", + authority: "4.3.2.1", + wantAuthority: "4.3.2.1:53", }, { - "::1", - "[::1]:" + defaultDNSSvrPort, - false, + name: "ipv6 authority with no port", + authority: "::1", + wantAuthority: "[::1]:53", }, { - "[::1]", - "[::1]:" + defaultDNSSvrPort, - false, + name: "ipv6 authority with brackets and no port", + authority: "[::1]", + wantAuthority: "[::1]:53", }, { - "[::1]:123", - "[::1]:123", - false, + name: "ipv6 authority with brackers and non-default DNS port", + authority: "[::1]:123", + wantAuthority: "[::1]:123", }, { - "dnsserver.com", - "dnsserver.com:" + defaultDNSSvrPort, - false, + name: "host name with no port", + authority: "dnsserver.com", + wantAuthority: "dnsserver.com:53", }, { - ":123", - "localhost:123", - false, + name: "no host port and non-default port", + authority: ":123", + wantAuthority: "localhost:123", }, { - ":", - "", - true, + name: "only colon", + authority: ":", + wantAuthority: "", + wantBuildErr: true, }, { - "[::1]:", - "", - true, + name: "ipv6 name ending in colon", + authority: "[::1]:", + wantAuthority: "", + wantBuildErr: true, }, { - "dnsserver.com:", - "", - true, + name: "host name ending in colon", + authority: "dnsserver.com:", + wantAuthority: "", + wantBuildErr: true, }, } - oldAddressDialer := addressDialer - defer func() { - addressDialer = oldAddressDialer - }() - - for _, a := range tests { - errChan := make(chan 
error, 1) - addressDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) { - if authority != a.authorityWant { - errChan <- fmt.Errorf("wrong custom authority passed to resolver. input: %s expected: %s actual: %s", a.authority, a.authorityWant, authority) - } else { - errChan <- nil - } - return func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, errors.New("no need to dial") - } - } - mockEndpointTarget := "foo.bar.com" - b := NewBuilder() - cc := &testClientConn{target: mockEndpointTarget, errChan: make(chan error, 1)} - target := resolver.Target{ - URL: *testutils.MustParseURL(fmt.Sprintf("scheme://%s/%s", a.authority, mockEndpointTarget)), - } - r, err := b.Build(target, cc, resolver.BuildOptions{}) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + overrideTimeAfterFunc(t, 2*defaultTestTimeout) + + // Override the address dialer to verify the authority being passed. + origAddressDialer := dnsinternal.AddressDialer + errChan := make(chan error, 1) + dnsinternal.AddressDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) { + if authority != test.wantAuthority { + errChan <- fmt.Errorf("wrong custom authority passed to resolver. 
target: %s got authority: %s want authority: %s", test.authority, authority, test.wantAuthority) + } else { + errChan <- nil + } + return func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("no need to dial") + } + } + defer func() { dnsinternal.AddressDialer = origAddressDialer }() - if err == nil { - r.Close() + b := resolver.Get("dns") + if b == nil { + t.Fatalf("Resolver for dns:/// scheme not registered") + } - err = <-errChan + tcc := &testutils.ResolverClientConn{Logger: t} + endpoint := "foo.bar.com" + target := resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns://%s/%s", test.authority, endpoint))} + r, err := b.Build(target, tcc, resolver.BuildOptions{}) + if (err != nil) != test.wantBuildErr { + t.Fatalf("DNS resolver build for target %+v returned error %v: wantErr: %v\n", target, err, test.wantBuildErr) + } if err != nil { - t.Errorf(err.Error()) + return } + defer r.Close() - if a.expectError { - t.Errorf("custom authority should have caused an error: %s", a.authority) + if err := <-errChan; err != nil { + t.Fatal(err) } - } else if !a.expectError { - t.Errorf("unexpected error using custom authority %s: %s", a.authority, err) - } + }) } } @@ -1466,60 +1083,21 @@ func TestCustomAuthority(t *testing.T) { // requests. It sets the re-resolution rate to a small value and repeatedly // calls ResolveNow() and ensures only the expected number of resolution // requests are made. - -func TestRateLimitedResolve(t *testing.T) { - defer leakcheck.Check(t) - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential - // backoff. 
- return time.NewTimer(time.Hour) - } - defer func(nt func(d time.Duration) *time.Timer) { - newTimerDNSResRate = nt - }(newTimerDNSResRate) - - timerChan := testutils.NewChannel() - newTimerDNSResRate = func(d time.Duration) *time.Timer { - // Will never fire on its own, allows this test to call timer - // immediately. - t := time.NewTimer(time.Hour) - timerChan.Send(t) - return t - } - - // Create a new testResolver{} for this test because we want the exact count - // of the number of times the resolver was invoked. - nc := overrideDefaultResolver(true) - defer nc() - - target := "foo.bar.com" - b := NewBuilder() - cc := &testClientConn{target: target} - - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("resolver.Build() returned error: %v\n", err) - } - defer r.Close() - - dnsR, ok := r.(*dnsResolver) - if !ok { - t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR) - } - - tr, ok := dnsR.resolver.(*testResolver) - if !ok { - t.Fatalf("delegate resolver returned unexpected type: %T\n", tr) +func (s) TestRateLimitedResolve(t *testing.T) { + const target = "foo.bar.com" + _, timeChan := overrideTimeAfterFuncWithChannel(t) + tr := &testNetResolver{ + lookupHostCh: testutils.NewChannel(), + hostLookupTable: map[string][]string{target: {"1.2.3.4", "5.6.7.8"}}, } + overrideNetResolver(t, tr) - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() + r, stateCh, _ := buildResolverWithTestClientConn(t, target) // Wait for the first resolution request to be done. This happens as part // of the first iteration of the for loop in watcher(). 
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() if _, err := tr.lookupHostCh.Receive(ctx); err != nil { t.Fatalf("Timed out waiting for lookup() call.") } @@ -1532,21 +1110,19 @@ func TestRateLimitedResolve(t *testing.T) { continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer continueCancel() - if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil { t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") } - // Make the DNSMinResRate timer fire immediately (by receiving it, then - // resetting to 0), this will unblock the resolver which is currently - // blocked on the DNS Min Res Rate timer going off, which will allow it to - // continue to the next iteration of the watcher loop. - timer, err := timerChan.Receive(ctx) - if err != nil { - t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + // Make the DNSMinResRate timer fire immediately, by sending the current + // time on it. This will unblock the resolver which is currently blocked on + // the DNS Min Res Rate timer going off, which will allow it to continue to + // the next iteration of the watcher loop. + select { + case timeChan <- time.Now(): + case <-ctx.Done(): + t.Fatal("Timed out waiting for the DNS resolver to block on DNS Min Res Rate to elapse") } - timerPointer := timer.(*time.Timer) - timerPointer.Reset(0) // Now that DNS Min Res Rate timer has gone off, it should lookup again. 
if _, err := tr.lookupHostCh.Receive(ctx); err != nil { @@ -1558,99 +1134,84 @@ func TestRateLimitedResolve(t *testing.T) { for i := 0; i < 1000; i++ { r.ResolveNow(resolver.ResolveNowOptions{}) } - - if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil { + continueCtx, continueCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer continueCancel() + if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil { t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") } // Make the DNSMinResRate timer fire immediately again. - timer, err = timerChan.Receive(ctx) - if err != nil { - t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + select { + case timeChan <- time.Now(): + case <-ctx.Done(): + t.Fatal("Timed out waiting for the DNS resolver to block on DNS Min Res Rate to elapse") } - timerPointer = timer.(*time.Timer) - timerPointer.Reset(0) // Now that DNS Min Res Rate timer has gone off, it should lookup again. - if _, err = tr.lookupHostCh.Receive(ctx); err != nil { + if _, err := tr.lookupHostCh.Receive(ctx); err != nil { t.Fatalf("Timed out waiting for lookup() call.") } wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}} var state resolver.State - for { - var cnt int - state, cnt = cc.getState() - if cnt > 0 { - break - } - time.Sleep(time.Millisecond) + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for a state update from the resolver") + case state = <-stateCh: } - if !reflect.DeepEqual(state.Addresses, wantAddrs) { - t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, wantAddrs) + if !cmp.Equal(state.Addresses, wantAddrs, cmpopts.EquateEmpty()) { + t.Fatalf("Got addresses: %+v, want: %+v", state.Addresses, wantAddrs) } } -// DNS Resolver immediately starts polling on an error. This will cause the re-resolution to return another error. 
-// Thus, test that it constantly sends errors to the grpc.ClientConn. -func TestReportError(t *testing.T) { +// Test verifies that when the DNS resolver gets an error from the underlying +// net.Resolver, it reports the error to the channel and backs off and retries. +func (s) TestReportError(t *testing.T) { + durChan, timeChan := overrideTimeAfterFuncWithChannel(t) + overrideNetResolver(t, &testNetResolver{}) + const target = "notfoundaddress" - defer func(nt func(d time.Duration) *time.Timer) { - newTimer = nt - }(newTimer) - timerChan := testutils.NewChannel() - newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, allows this test to call timer immediately. - t := time.NewTimer(time.Hour) - timerChan.Send(t) - return t - } - cc := &testClientConn{target: target, errChan: make(chan error)} - totalTimesCalledError := 0 - b := NewBuilder() - r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{}) - if err != nil { - t.Fatalf("Error building resolver for target %v: %v", target, err) - } + _, _, errorCh := buildResolverWithTestClientConn(t, target) + // Should receive first error. - err = <-cc.errChan - if !strings.Contains(err.Error(), "hostLookup error") { - t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) - } - totalTimesCalledError++ ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() - timer, err := timerChan.Receive(ctx) - if err != nil { - t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) - } - timerPointer := timer.(*time.Timer) - timerPointer.Reset(0) - defer r.Close() - - // Cause timer to go off 10 times, and see if it matches DNS Resolver updating Error. - for i := 0; i < 10; i++ { - // Should call ReportError(). 
- err = <-cc.errChan + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for an error from the resolver") + case err := <-errorCh: if !strings.Contains(err.Error(), "hostLookup error") { t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) } - totalTimesCalledError++ - timer, err := timerChan.Receive(ctx) - if err != nil { - t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) - } - timerPointer := timer.(*time.Timer) - timerPointer.Reset(0) } - if totalTimesCalledError != 11 { - t.Errorf("ReportError() not called 11 times, instead called %d times.", totalTimesCalledError) - } - // Clean up final watcher iteration. - <-cc.errChan - _, err = timerChan.Receive(ctx) - if err != nil { - t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + // Expect the DNS resolver to backoff and attempt to re-resolve. Every time, + // the DNS resolver will receive the same error from the net.Resolver and is + // expected to push it to the channel. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + const retries = 10 + var prevDur time.Duration + for i := 0; i < retries; i++ { + select { + case <-ctx.Done(): + t.Fatalf("(Iteration: %d): Timeout when waiting for DNS resolver to backoff", i) + case dur := <-durChan: + if dur <= prevDur { + t.Fatalf("(Iteration: %d): Unexpected decrease in amount of time to backoff", i) + } + } + + // Unblock the DNS resolver's backoff by pushing the current time. 
+ timeChan <- time.Now() + + select { + case <-ctx.Done(): + t.Fatal("Timeout when waiting for an error from the resolver") + case err := <-errorCh: + if !strings.Contains(err.Error(), "hostLookup error") { + t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) + } + } } } diff --git a/internal/resolver/dns/fake_net_resolver_test.go b/internal/resolver/dns/fake_net_resolver_test.go new file mode 100644 index 000000000000..a3be31607b39 --- /dev/null +++ b/internal/resolver/dns/fake_net_resolver_test.go @@ -0,0 +1,123 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package dns_test + +import ( + "context" + "net" + "sync" + + "google.golang.org/grpc/internal/testutils" +) + +// A fake implementation of the internal.NetResolver interface for use in tests. +type testNetResolver struct { + // A write to this channel is made when this resolver receives a resolution + // request. Tests can rely on reading from this channel to be notified about + // resolution requests instead of sleeping for a predefined period of time. 
+ lookupHostCh *testutils.Channel + + mu sync.Mutex + hostLookupTable map[string][]string // Name --> list of addresses + srvLookupTable map[string][]*net.SRV // Name --> list of SRV records + txtLookupTable map[string][]string // Name --> service config for TXT record +} + +func (tr *testNetResolver) LookupHost(ctx context.Context, host string) ([]string, error) { + if tr.lookupHostCh != nil { + tr.lookupHostCh.Send(nil) + } + + tr.mu.Lock() + defer tr.mu.Unlock() + + if addrs, ok := tr.hostLookupTable[host]; ok { + return addrs, nil + } + return nil, &net.DNSError{ + Err: "hostLookup error", + Name: host, + Server: "fake", + IsTemporary: true, + } +} + +func (tr *testNetResolver) UpdateHostLookupTable(table map[string][]string) { + tr.mu.Lock() + tr.hostLookupTable = table + tr.mu.Unlock() +} + +func (tr *testNetResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { + tr.mu.Lock() + defer tr.mu.Unlock() + + cname := "_" + service + "._" + proto + "." + name + if srvs, ok := tr.srvLookupTable[cname]; ok { + return cname, srvs, nil + } + return "", nil, &net.DNSError{ + Err: "srvLookup error", + Name: cname, + Server: "fake", + IsTemporary: true, + } +} + +func (tr *testNetResolver) LookupTXT(ctx context.Context, host string) ([]string, error) { + tr.mu.Lock() + defer tr.mu.Unlock() + + if sc, ok := tr.txtLookupTable[host]; ok { + return sc, nil + } + return nil, &net.DNSError{ + Err: "txtLookup error", + Name: host, + Server: "fake", + IsTemporary: true, + } +} + +func (tr *testNetResolver) UpdateTXTLookupTable(table map[string][]string) { + tr.mu.Lock() + tr.txtLookupTable = table + tr.mu.Unlock() +} + +// txtRecordServiceConfig generates a slice of strings (aggregately representing +// a single service config file) for the input config string, that represents +// the result from a real DNS TXT record lookup. 
+func txtRecordServiceConfig(cfg string) []string { + // In DNS, service config is encoded in a TXT record via the mechanism + // described in RFC-1464 using the attribute name grpc_config. + b := append([]byte("grpc_config="), []byte(cfg)...) + + // Split b into multiple strings, each with a max of 255 bytes, which is + // the DNS TXT record limit. + var r []string + for i := 0; i < len(b); i += txtBytesLimit { + if i+txtBytesLimit > len(b) { + r = append(r, string(b[i:])) + } else { + r = append(r, string(b[i:i+txtBytesLimit])) + } + } + return r +} diff --git a/internal/resolver/dns/internal/internal.go b/internal/resolver/dns/internal/internal.go new file mode 100644 index 000000000000..c7fc557d00c1 --- /dev/null +++ b/internal/resolver/dns/internal/internal.go @@ -0,0 +1,70 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package internal contains functionality internal to the dns resolver package. +package internal + +import ( + "context" + "errors" + "net" + "time" +) + +// NetResolver groups the methods on net.Resolver that are used by the DNS +// resolver implementation. This allows the default net.Resolver instance to be +// overridden from tests. 
+type NetResolver interface { + LookupHost(ctx context.Context, host string) (addrs []string, err error) + LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error) + LookupTXT(ctx context.Context, name string) (txts []string, err error) +} + +var ( + // ErrMissingAddr is the error returned when building a DNS resolver when + // the provided target name is empty. + ErrMissingAddr = errors.New("dns resolver: missing address") + + // ErrEndsWithColon is the error returned when building a DNS resolver when + // the provided target name ends with a colon that is supposed to be the + // separator between host and port. E.g. "::" is a valid address as it is + // an IPv6 address (host only) and "[::]:" is invalid as it ends with a + // colon as the host and port separator + ErrEndsWithColon = errors.New("dns resolver: missing port after port-separator colon") +) + +// The following vars are overridden from tests. +var ( + // MinResolutionRate is the minimum rate at which re-resolutions are + // allowed. This helps to prevent excessive re-resolution. + MinResolutionRate = 30 * time.Second + + // TimeAfterFunc is used by the DNS resolver to wait for the given duration + // to elapse. In non-test code, this is implemented by time.After. In test + // code, this can be used to control the amount of time the resolver is + // blocked waiting for the duration to elapse. + TimeAfterFunc func(time.Duration) <-chan time.Time + + // NewNetResolver returns the net.Resolver instance for the given target. + NewNetResolver func(string) (NetResolver, error) + + // AddressDialer is the dialer used to dial the DNS server. It accepts the + // Host portion of the URL corresponding to the user's dial target and + // returns a dial function. 
+ AddressDialer func(address string) func(context.Context, string, string) (net.Conn, error) +) diff --git a/internal/testutils/balancer.go b/internal/testutils/balancer.go index 43bbbf9ae560..c65be16be4b6 100644 --- a/internal/testutils/balancer.go +++ b/internal/testutils/balancer.go @@ -30,16 +30,9 @@ import ( "google.golang.org/grpc/resolver" ) -// testingLogger wraps the logging methods from testing.T. -type testingLogger interface { - Log(args ...any) - Logf(format string, args ...any) - Errorf(format string, args ...any) -} - // TestSubConn implements the SubConn interface, to be used in tests. type TestSubConn struct { - tcc *TestClientConn // the CC that owns this SubConn + tcc *BalancerClientConn // the CC that owns this SubConn id string ConnectCh chan struct{} stateListener func(balancer.SubConnState) @@ -98,9 +91,9 @@ func (tsc *TestSubConn) String() string { return tsc.id } -// TestClientConn is a mock balancer.ClientConn used in tests. -type TestClientConn struct { - logger testingLogger +// BalancerClientConn is a mock balancer.ClientConn used in tests. +type BalancerClientConn struct { + logger Logger NewSubConnAddrsCh chan []resolver.Address // the last 10 []Address to create subconn. NewSubConnCh chan *TestSubConn // the last 10 subconn created. @@ -114,9 +107,9 @@ type TestClientConn struct { subConnIdx int } -// NewTestClientConn creates a TestClientConn. -func NewTestClientConn(t *testing.T) *TestClientConn { - return &TestClientConn{ +// NewBalancerClientConn creates a BalancerClientConn. +func NewBalancerClientConn(t *testing.T) *BalancerClientConn { + return &BalancerClientConn{ logger: t, NewSubConnAddrsCh: make(chan []resolver.Address, 10), @@ -131,7 +124,7 @@ func NewTestClientConn(t *testing.T) *TestClientConn { } // NewSubConn creates a new SubConn. 
-func (tcc *TestClientConn) NewSubConn(a []resolver.Address, o balancer.NewSubConnOptions) (balancer.SubConn, error) { +func (tcc *BalancerClientConn) NewSubConn(a []resolver.Address, o balancer.NewSubConnOptions) (balancer.SubConn, error) { sc := &TestSubConn{ tcc: tcc, id: fmt.Sprintf("sc%d", tcc.subConnIdx), @@ -156,13 +149,13 @@ func (tcc *TestClientConn) NewSubConn(a []resolver.Address, o balancer.NewSubCon // RemoveSubConn is a nop; tests should all be updated to use sc.Shutdown() // instead. -func (tcc *TestClientConn) RemoveSubConn(sc balancer.SubConn) { +func (tcc *BalancerClientConn) RemoveSubConn(sc balancer.SubConn) { tcc.logger.Errorf("RemoveSubConn(%v) called unexpectedly", sc) } // UpdateAddresses updates the addresses on the SubConn. -func (tcc *TestClientConn) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { - tcc.logger.Logf("testClientConn: UpdateAddresses(%v, %+v)", sc, addrs) +func (tcc *BalancerClientConn) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { + tcc.logger.Logf("testutils.BalancerClientConn: UpdateAddresses(%v, %+v)", sc, addrs) select { case tcc.UpdateAddressesAddrsCh <- addrs: default: @@ -170,8 +163,8 @@ func (tcc *TestClientConn) UpdateAddresses(sc balancer.SubConn, addrs []resolver } // UpdateState updates connectivity state and picker. -func (tcc *TestClientConn) UpdateState(bs balancer.State) { - tcc.logger.Logf("testClientConn: UpdateState(%v)", bs) +func (tcc *BalancerClientConn) UpdateState(bs balancer.State) { + tcc.logger.Logf("testutils.BalancerClientConn: UpdateState(%v)", bs) select { case <-tcc.NewStateCh: default: @@ -186,7 +179,7 @@ func (tcc *TestClientConn) UpdateState(bs balancer.State) { } // ResolveNow panics. 
-func (tcc *TestClientConn) ResolveNow(o resolver.ResolveNowOptions) { +func (tcc *BalancerClientConn) ResolveNow(o resolver.ResolveNowOptions) { select { case <-tcc.ResolveNowCh: default: @@ -195,14 +188,14 @@ func (tcc *TestClientConn) ResolveNow(o resolver.ResolveNowOptions) { } // Target panics. -func (tcc *TestClientConn) Target() string { +func (tcc *BalancerClientConn) Target() string { panic("not implemented") } // WaitForErrPicker waits until an error picker is pushed to this ClientConn. // Returns error if the provided context expires or a non-error picker is pushed // to the ClientConn. -func (tcc *TestClientConn) WaitForErrPicker(ctx context.Context) error { +func (tcc *BalancerClientConn) WaitForErrPicker(ctx context.Context) error { select { case <-ctx.Done(): return errors.New("timeout when waiting for an error picker") @@ -218,7 +211,7 @@ func (tcc *TestClientConn) WaitForErrPicker(ctx context.Context) error { // ClientConn with the error matching the wanted error. Returns an error if // the provided context expires, including the last received picker error (if // any). -func (tcc *TestClientConn) WaitForPickerWithErr(ctx context.Context, want error) error { +func (tcc *BalancerClientConn) WaitForPickerWithErr(ctx context.Context, want error) error { lastErr := errors.New("received no picker") for { select { @@ -235,7 +228,7 @@ func (tcc *TestClientConn) WaitForPickerWithErr(ctx context.Context, want error) // WaitForConnectivityState waits until the state pushed to this ClientConn // matches the wanted state. Returns an error if the provided context expires, // including the last received state (if any). 
-func (tcc *TestClientConn) WaitForConnectivityState(ctx context.Context, want connectivity.State) error { +func (tcc *BalancerClientConn) WaitForConnectivityState(ctx context.Context, want connectivity.State) error { var lastState connectivity.State = -1 for { select { @@ -255,7 +248,7 @@ func (tcc *TestClientConn) WaitForConnectivityState(ctx context.Context, want co // is pending) to be considered. Returns an error if the provided context // expires, including the last received error from IsRoundRobin or the picker // (if any). -func (tcc *TestClientConn) WaitForRoundRobinPicker(ctx context.Context, want ...balancer.SubConn) error { +func (tcc *BalancerClientConn) WaitForRoundRobinPicker(ctx context.Context, want ...balancer.SubConn) error { lastErr := errors.New("received no picker") for { select { @@ -294,7 +287,7 @@ func (tcc *TestClientConn) WaitForRoundRobinPicker(ctx context.Context, want ... // WaitForPicker waits for a picker that results in f returning nil. If the // context expires, returns the last error returned by f (if any). -func (tcc *TestClientConn) WaitForPicker(ctx context.Context, f func(balancer.Picker) error) error { +func (tcc *BalancerClientConn) WaitForPicker(ctx context.Context, f func(balancer.Picker) error) error { lastErr := errors.New("received no picker") for { select { diff --git a/internal/testutils/resolver.go b/internal/testutils/resolver.go new file mode 100644 index 000000000000..943436855004 --- /dev/null +++ b/internal/testutils/resolver.go @@ -0,0 +1,70 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package testutils + +import ( + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" +) + +// Logger wraps the logging methods from testing.T. +type Logger interface { + Log(args ...any) + Logf(format string, args ...any) + Errorf(format string, args ...any) +} + +// ResolverClientConn is a fake implementation of the resolver.ClientConn +// interface to be used in tests. +type ResolverClientConn struct { + resolver.ClientConn // Embedding the interface to avoid implementing deprecated methods. + + Logger Logger // Tests should pass testing.T for this. + UpdateStateF func(resolver.State) error // Invoked when resolver pushes a state update. + ReportErrorF func(err error) // Invoked when resolver pushes an error. +} + +// UpdateState invokes the test-specified callback with the update received from +// the resolver. If the callback returns a non-nil error, the same will be +// propagated to the resolver. +func (t *ResolverClientConn) UpdateState(s resolver.State) error { + t.Logger.Logf("testutils.ResolverClientConn: UpdateState(%s)", pretty.ToJSON(s)) + + if t.UpdateStateF != nil { + return t.UpdateStateF(s) + } + return nil +} + +// ReportError invokes the test-specified callback with the error received from the resolver.
+func (t *ResolverClientConn) ReportError(err error) { + t.Logger.Logf("testutils.ResolverClientConn: ReportError(%v)", err) + + if t.ReportErrorF != nil { + t.ReportErrorF(err) + } +} + +// ParseServiceConfig parses the provided service config by delegating the work to the +// implementation in the grpc package. +func (t *ResolverClientConn) ParseServiceConfig(jsonSC string) *serviceconfig.ParseResult { + return internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(jsonSC) +} diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index 2f0099d0b6c5..1edf3b8b857a 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -92,7 +92,7 @@ func (s) TestDropByCategory(t *testing.T) { xdsC := fakeclient.NewClient() builder := balancer.Get(Name) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() @@ -251,7 +251,7 @@ func (s) TestDropCircuitBreaking(t *testing.T) { xdsC := fakeclient.NewClient() builder := balancer.Get(Name) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() @@ -363,7 +363,7 @@ func (s) TestPickerUpdateAfterClose(t *testing.T) { xdsC := fakeclient.NewClient() builder := balancer.Get(Name) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) // Create a stub balancer which waits for the cluster_impl policy to be @@ -436,7 +436,7 @@ func (s) TestClusterNameInAddressAttributes(t *testing.T) { xdsC := fakeclient.NewClient() builder := balancer.Get(Name) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() @@ -511,7 +511,7 @@ func (s) TestReResolution(t *testing.T) { xdsC :=
fakeclient.NewClient() builder := balancer.Get(Name) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() @@ -578,7 +578,7 @@ func (s) TestLoadReporting(t *testing.T) { xdsC := fakeclient.NewClient() builder := balancer.Get(Name) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() @@ -692,7 +692,7 @@ func (s) TestUpdateLRSServer(t *testing.T) { xdsC := fakeclient.NewClient() builder := balancer.Get(Name) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() diff --git a/xds/internal/balancer/clustermanager/clustermanager_test.go b/xds/internal/balancer/clustermanager/clustermanager_test.go index 39e32d60993d..a00a2836060a 100644 --- a/xds/internal/balancer/clustermanager/clustermanager_test.go +++ b/xds/internal/balancer/clustermanager/clustermanager_test.go @@ -75,7 +75,7 @@ func testPick(t *testing.T, p balancer.Picker, info balancer.PickInfo, wantSC ba } func TestClusterPicks(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) builder := balancer.Get(balancerName) parser := builder.(balancer.ConfigParser) bal := builder.Build(cc, balancer.BuildOptions{}) @@ -154,7 +154,7 @@ func TestClusterPicks(t *testing.T) { // TestConfigUpdateAddCluster covers the cases the balancer receives config // update with extra clusters. 
func TestConfigUpdateAddCluster(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) builder := balancer.Get(balancerName) parser := builder.(balancer.ConfigParser) bal := builder.Build(cc, balancer.BuildOptions{}) @@ -312,7 +312,7 @@ func TestConfigUpdateAddCluster(t *testing.T) { // TestRoutingConfigUpdateDeleteAll covers the cases the balancer receives // config update with no clusters. Pick should fail with details in error. func TestRoutingConfigUpdateDeleteAll(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) builder := balancer.Get(balancerName) parser := builder.(balancer.ConfigParser) bal := builder.Build(cc, balancer.BuildOptions{}) @@ -498,7 +498,7 @@ func TestClusterManagerForwardsBalancerBuildOptions(t *testing.T) { }, }) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) builder := balancer.Get(balancerName) parser := builder.(balancer.ConfigParser) bal := builder.Build(cc, bOpts) @@ -558,7 +558,7 @@ func init() { // TestInitialIdle covers the case that if the child reports Idle, the overall // state will be Idle. func TestInitialIdle(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) builder := balancer.Get(balancerName) parser := builder.(balancer.ConfigParser) bal := builder.Build(cc, balancer.BuildOptions{}) @@ -605,7 +605,7 @@ func TestInitialIdle(t *testing.T) { // it's state and completes the graceful switch process the new picker should // reflect this change. func TestClusterGracefulSwitch(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) builder := balancer.Get(balancerName) parser := builder.(balancer.ConfigParser) bal := builder.Build(cc, balancer.BuildOptions{}) @@ -708,17 +708,17 @@ func TestClusterGracefulSwitch(t *testing.T) { // tcc wraps a testutils.TestClientConn but stores all state transitions in a // slice. 
type tcc struct { - *testutils.TestClientConn + *testutils.BalancerClientConn states []balancer.State } func (t *tcc) UpdateState(bs balancer.State) { t.states = append(t.states, bs) - t.TestClientConn.UpdateState(bs) + t.BalancerClientConn.UpdateState(bs) } func (s) TestUpdateStatePauses(t *testing.T) { - cc := &tcc{TestClientConn: testutils.NewTestClientConn(t)} + cc := &tcc{BalancerClientConn: testutils.NewBalancerClientConn(t)} balFuncs := stub.BalancerFuncs{ UpdateClientConnState: func(bd *stub.BalancerData, s balancer.ClientConnState) error { diff --git a/xds/internal/balancer/outlierdetection/balancer_test.go b/xds/internal/balancer/outlierdetection/balancer_test.go index c6f0ac7ee128..32c3a378d7ea 100644 --- a/xds/internal/balancer/outlierdetection/balancer_test.go +++ b/xds/internal/balancer/outlierdetection/balancer_test.go @@ -543,13 +543,13 @@ type subConnWithState struct { state balancer.SubConnState } -func setup(t *testing.T) (*outlierDetectionBalancer, *testutils.TestClientConn, func()) { +func setup(t *testing.T) (*outlierDetectionBalancer, *testutils.BalancerClientConn, func()) { t.Helper() builder := balancer.Get(Name) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", Name) } - tcc := testutils.NewTestClientConn(t) + tcc := testutils.NewBalancerClientConn(t) odB := builder.Build(tcc, balancer.BuildOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefChannel, time.Now().Unix(), nil)}) return odB.(*outlierDetectionBalancer), tcc, odB.Close } diff --git a/xds/internal/balancer/priority/balancer_test.go b/xds/internal/balancer/priority/balancer_test.go index efdd280d029e..5d4596fd335c 100644 --- a/xds/internal/balancer/priority/balancer_test.go +++ b/xds/internal/balancer/priority/balancer_test.go @@ -81,7 +81,7 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := 
testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -195,7 +195,7 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -358,7 +358,7 @@ func (s) TestPriority_HighPriorityToConnectingFromReady(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -442,7 +442,7 @@ func (s) TestPriority_HigherDownWhileAddingLower(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -543,7 +543,7 @@ func (s) TestPriority_HigherReadyCloseAllLower(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -645,7 +645,7 @@ func (s) TestPriority_InitTimeout(t *testing.T) { } }()() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -716,7 +716,7 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { } }()() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ 
-875,7 +875,7 @@ func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -967,7 +967,7 @@ func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { }(DefaultPriorityInitTimeout) DefaultPriorityInitTimeout = testPriorityInitTimeout - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1013,7 +1013,7 @@ func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1110,7 +1110,7 @@ func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1206,7 +1206,7 @@ func (s) TestPriority_RemoveReadyLowestChild(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1310,7 +1310,7 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { } }()() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1409,7 +1409,7 @@ 
func (s) TestPriority_ChildPolicyChange(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1508,7 +1508,7 @@ func (s) TestPriority_ChildPolicyUpdatePickerInline(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1551,7 +1551,7 @@ func (s) TestPriority_IgnoreReresolutionRequest(t *testing.T) { }, }) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1650,7 +1650,7 @@ func (s) TestPriority_IgnoreReresolutionRequestTwoChildren(t *testing.T) { }, }) - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1772,7 +1772,7 @@ func (s) TestPriority_HighPriorityInitIdle(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1838,7 +1838,7 @@ func (s) TestPriority_AddLowPriorityWhenHighIsInIdle(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1919,7 +1919,7 @@ func (s) TestPriority_HighPriorityUpdatesWhenLowInUse(t *testing.T) { ctx, cancel := 
context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() diff --git a/xds/internal/balancer/priority/ignore_resolve_now_test.go b/xds/internal/balancer/priority/ignore_resolve_now_test.go index 5a0083147888..661e9d052afd 100644 --- a/xds/internal/balancer/priority/ignore_resolve_now_test.go +++ b/xds/internal/balancer/priority/ignore_resolve_now_test.go @@ -28,7 +28,7 @@ import ( ) func (s) TestIgnoreResolveNowClientConn(t *testing.T) { - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) ignoreCC := newIgnoreResolveNowClientConn(cc, false) // Call ResolveNow() on the CC, it should be forwarded. diff --git a/xds/internal/balancer/ringhash/ringhash_test.go b/xds/internal/balancer/ringhash/ringhash_test.go index 16872dd346b1..a1edfe5d228a 100644 --- a/xds/internal/balancer/ringhash/ringhash_test.go +++ b/xds/internal/balancer/ringhash/ringhash_test.go @@ -67,9 +67,9 @@ func ctxWithHash(h uint64) context.Context { } // setupTest creates the balancer, and does an initial sanity check. 
-func setupTest(t *testing.T, addrs []resolver.Address) (*testutils.TestClientConn, balancer.Balancer, balancer.Picker) { +func setupTest(t *testing.T, addrs []resolver.Address) (*testutils.BalancerClientConn, balancer.Balancer, balancer.Picker) { t.Helper() - cc := testutils.NewTestClientConn(t) + cc := testutils.NewBalancerClientConn(t) builder := balancer.Get(Name) b := builder.Build(cc, balancer.BuildOptions{}) if b == nil { diff --git a/xds/internal/balancer/wrrlocality/balancer_test.go b/xds/internal/balancer/wrrlocality/balancer_test.go index f0da7413bdb8..ab4167350322 100644 --- a/xds/internal/balancer/wrrlocality/balancer_test.go +++ b/xds/internal/balancer/wrrlocality/balancer_test.go @@ -167,7 +167,7 @@ func (s) TestUpdateClientConnState(t *testing.T) { if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", Name) } - tcc := testutils.NewTestClientConn(t) + tcc := testutils.NewBalancerClientConn(t) bal := builder.Build(tcc, balancer.BuildOptions{}) defer bal.Close() wrrL := bal.(*wrrLocalityBalancer)