Skip to content

Commit

Permalink
upstream: add staleness cache
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Feb 20, 2024
1 parent d0c2995 commit f4d9592
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 17 deletions.
32 changes: 17 additions & 15 deletions upstream/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (r *UpstreamResolver) LookupNetIP(

host = dns.Fqdn(strings.ToLower(host))

res, err := r.resolveIP(ctx, network, host)
res, err := r.lookupNetIP(ctx, network, host)
if err != nil {
return []netip.Addr{}, err
}
Expand All @@ -151,19 +151,19 @@ type ipResult struct {
addrs []netip.Addr
}

// resolveIP performs a DNS lookup of host and returns the result. network must
// be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6] or
// lookupNetIP performs a DNS lookup of host and returns the result. network
// must be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6], or
// [bootstrap.NetworkIP]. host must be in a lower-case FQDN form.
//
// TODO(e.burkov): Use context.
func (r *UpstreamResolver) resolveIP(
func (r *UpstreamResolver) lookupNetIP(
_ context.Context,
network bootstrap.Network,
host string,
) (result *ipResult, err error) {
switch network {
case bootstrap.NetworkIP4, bootstrap.NetworkIP6:
return r.resolve(host, network)
return r.request(host, network)
case bootstrap.NetworkIP:
// Go on.
default:
Expand Down Expand Up @@ -192,16 +192,14 @@ func (r *UpstreamResolver) resolveIP(
return result, errors.Join(errs...)
}

// resolve performs a single DNS lookup of host and returns all the valid
// request performs a single DNS lookup of host and returns all the valid
// addresses from the answer section of the response. network must be either
// "ip4" or "ip6". host must be in a lower-case FQDN form.
// [bootstrap.NetworkIP4], or [bootstrap.NetworkIP6]. host must be in a
// lower-case FQDN form.
//
// TODO(e.burkov): Consider NS and Extra sections when setting TTL. Check out
// what RFCs say about it.
func (r *UpstreamResolver) resolve(
host string,
n bootstrap.Network,
) (res *ipResult, err error) {
func (r *UpstreamResolver) request(host string, n bootstrap.Network) (res *ipResult, err error) {
var qtype uint16
switch n {
case bootstrap.NetworkIP4:
Expand Down Expand Up @@ -231,7 +229,10 @@ func (r *UpstreamResolver) resolve(
return res, err
}

res = &ipResult{}
res = &ipResult{
expire: time.Now(),
addrs: make([]netip.Addr, 0, len(resp.Answer)),
}
var minTTL uint32 = math.MaxUint32

for _, rr := range resp.Answer {
Expand All @@ -240,17 +241,18 @@ func (r *UpstreamResolver) resolve(
continue
}

minTTL = min(minTTL, rr.Header().Ttl)
res.addrs = append(res.addrs, ip)
minTTL = min(minTTL, rr.Header().Ttl)
}
res.expire = res.expire.Add(time.Duration(minTTL) * time.Second)

return res, nil
}

// resolveAsync performs a single DNS lookup and sends the result to ch. It's
// intended to be used as a goroutine.
func (r *UpstreamResolver) resolveAsync(resCh chan<- any, host, network string) {
res, err := r.resolve(host, network)
res, err := r.request(host, network)
if err != nil {
resCh <- err
} else {
Expand Down Expand Up @@ -302,7 +304,7 @@ func (r *CachingResolver) LookupNetIP(
return addrs, nil
}

newRes, err := r.resolver.resolveIP(ctx, network, host)
newRes, err := r.resolver.lookupNetIP(ctx, network, host)
if err != nil {
return []netip.Addr{}, err
}
Expand Down
102 changes: 100 additions & 2 deletions upstream/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,28 @@ package upstream

import (
"context"
"net/netip"
"testing"
"time"

"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewUpstreamResolver(t *testing.T) {
r, err := NewUpstreamResolver("1.1.1.1:53", &Options{Timeout: 3 * time.Second})
require.NoError(t, err)
ups := &FakeUpstream{
OnAddress: func() (_ string) { panic("not implemented") },
OnClose: func() (_ error) { panic("not implemented") },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return respondToTestMessage(req), nil
},
}

r := &UpstreamResolver{Upstream: ups}

ipAddrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com")
require.NoError(t, err)
Expand Down Expand Up @@ -90,3 +101,90 @@ func TestNewUpstreamResolver_validity(t *testing.T) {
})
}
}

func TestCachingResolver_staleness(t *testing.T) {
ip4 := netip.MustParseAddr("1.2.3.4")
ip6 := netip.MustParseAddr("2001:db8::1")

const (
smallTTL = 10 * time.Second
largeTTL = 1000 * time.Second

fqdn = "cloudflare-dns.com."
)

onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = (&dns.Msg{}).SetReply(req)

hdr := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: req.Question[0].Qtype,
Class: dns.ClassINET,
}
var rr dns.RR
switch q := req.Question[0]; q.Qtype {
case dns.TypeA:
hdr.Ttl = uint32(smallTTL.Seconds())
rr = &dns.A{Hdr: hdr, A: ip4.AsSlice()}
case dns.TypeAAAA:
hdr.Ttl = uint32(largeTTL.Seconds())
rr = &dns.AAAA{Hdr: hdr, AAAA: ip6.AsSlice()}
default:
require.Contains(testutil.PanicT{}, []uint16{dns.TypeA, dns.TypeAAAA}, q.Qtype)
}
resp.Answer = append(resp.Answer, rr)

return resp, nil
}

ups := &FakeUpstream{
OnAddress: func() (_ string) { panic("not implemented") },
OnClose: func() (_ error) { panic("not implemented") },
OnExchange: onExchange,
}

r := NewCachingResolver(&UpstreamResolver{Upstream: ups})

require.True(t, t.Run("resolve", func(t *testing.T) {
testCases := []struct {
name string
network bootstrap.Network
want []netip.Addr
}{{
name: "ip4",
network: bootstrap.NetworkIP4,
want: []netip.Addr{ip4},
}, {
name: "ip6",
network: bootstrap.NetworkIP6,
want: []netip.Addr{ip6},
}, {
name: "both",
network: bootstrap.NetworkIP,
want: []netip.Addr{ip4, ip6},
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.name != "both" {
t.Skip(`TODO(e.burkov): Bootstrap now only uses "ip" network, see TODO there.`)
}

res, err := r.LookupNetIP(context.Background(), tc.network, fqdn)
require.NoError(t, err)

assert.ElementsMatch(t, tc.want, res)
})
}
}))

t.Run("staleness", func(t *testing.T) {
r.mu.Lock()
defer r.mu.Unlock()

require.Contains(t, r.cached, fqdn)

cached := r.cached[fqdn]
assert.Less(t, time.Until(cached.expire), smallTTL)
})
}
30 changes: 30 additions & 0 deletions upstream/upstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,36 @@ func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}

// FakeUpstream is a fake [Upstream] implementation for tests.
//
// TODO(e.burkov): Move this into some fake package in this module or the
// golibs.
//
// TODO(e.burkov): Replace the actual upstreams with this in external tests.
type FakeUpstream struct {
OnAddress func() (addr string)
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
OnClose func() (err error)
}

// type check
var _ Upstream = (*FakeUpstream)(nil)

// Address implements the [Upstream] interface for *FakeUpstream.
func (u *FakeUpstream) Address() (addr string) {
return u.OnAddress()
}

// Exchange implements the [Upstream] interface for *FakeUpstream.
func (u *FakeUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
return u.OnExchange(req)
}

// Close implements the [Upstream] interface for *FakeUpstream.
func (u *FakeUpstream) Close() (err error) {
return u.OnClose()
}

func TestUpstream_bootstrapTimeout(t *testing.T) {
const (
timeout = 100 * time.Millisecond
Expand Down

0 comments on commit f4d9592

Please # to comment.