Skip to content

Commit

Permalink
all: fix resolver cache
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Feb 19, 2024
1 parent d918c7f commit d0c2995
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 42 deletions.
1 change: 1 addition & 0 deletions internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func ResolveDialContext(
defer cancel()
}

// TODO(e.burkov): Use network properly, perhaps, pass it through options.
ips, err := r.LookupNetIP(ctx, NetworkIP, host)
if err != nil {
return nil, fmt.Errorf("resolving hostname: %w", err)
Expand Down
73 changes: 31 additions & 42 deletions upstream/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package upstream
import (
"context"
"fmt"
"math"
"net/netip"
"net/url"
"strings"
Expand Down Expand Up @@ -135,35 +136,19 @@ func (r *UpstreamResolver) LookupNetIP(

host = dns.Fqdn(strings.ToLower(host))

rr, err := r.resolveIP(ctx, network, host)
res, err := r.resolveIP(ctx, network, host)
if err != nil {
return []netip.Addr{}, err
}

for _, ip := range rr {
ips = append(ips, ip.addr)
}

return ips, err
return res.addrs, err
}

// ipResult reflects a single A/AAAA record from the DNS response. It's used
// to cache the results of lookups.
type ipResult struct {
addr netip.Addr
expire time.Time
}

// filterExpired returns the addresses from res that are not expired yet. It
// returns nil if all the addresses are expired.
func filterExpired(res []ipResult, now time.Time) (filtered []netip.Addr) {
for _, r := range res {
if r.expire.After(now) {
filtered = append(filtered, r.addr)
}
}

return filtered
addrs []netip.Addr
}

// resolveIP performs a DNS lookup of host and returns the result. network must
Expand All @@ -175,32 +160,36 @@ func (r *UpstreamResolver) resolveIP(
_ context.Context,
network bootstrap.Network,
host string,
) (rr []ipResult, err error) {
) (result *ipResult, err error) {
switch network {
case bootstrap.NetworkIP4, bootstrap.NetworkIP6:
return r.resolve(host, network)
case bootstrap.NetworkIP:
// Go on.
default:
return nil, fmt.Errorf("unsupported network %s", network)
return result, fmt.Errorf("unsupported network %s", network)
}

resCh := make(chan any, 2)
go r.resolveAsync(resCh, host, bootstrap.NetworkIP4)
go r.resolveAsync(resCh, host, bootstrap.NetworkIP6)

var errs []error
result = &ipResult{}

for i := 0; i < 2; i++ {
switch res := <-resCh; res := res.(type) {
case error:
errs = append(errs, res)
case []ipResult:
rr = append(rr, res...)
case *ipResult:
if res.expire.Before(result.expire) {
result.expire = res.expire
}
result.addrs = append(result.addrs, res.addrs...)
}
}

return rr, errors.Join(errs...)
return result, errors.Join(errs...)
}

// resolve performs a single DNS lookup of host and returns all the valid
Expand All @@ -212,7 +201,7 @@ func (r *UpstreamResolver) resolveIP(
func (r *UpstreamResolver) resolve(
host string,
n bootstrap.Network,
) (res []ipResult, err error) {
) (res *ipResult, err error) {
var qtype uint16
switch n {
case bootstrap.NetworkIP4:
Expand All @@ -235,24 +224,24 @@ func (r *UpstreamResolver) resolve(
}},
}

// As per [upstream.Exchange] documentation, the response is always returned
// As per [Upstream.Exchange] documentation, the response is always returned
// if no error occurred.
resp, err := r.Exchange(req)
if err != nil {
return nil, err
return res, err
}

now := time.Now()
res = &ipResult{}
var minTTL uint32 = math.MaxUint32

for _, rr := range resp.Answer {
ip := proxyutil.IPFromRR(rr)
if !ip.IsValid() {
continue
}

res = append(res, ipResult{
addr: ip,
expire: now.Add(time.Duration(rr.Header().Ttl) * time.Second),
})
minTTL = min(minTTL, rr.Header().Ttl)
res.addrs = append(res.addrs, ip)
}

return res, nil
Expand All @@ -279,22 +268,27 @@ type CachingResolver struct {
mu *sync.RWMutex

// cached is the set of cached results sorted by [resolveResult.name].
cached map[string][]ipResult
//
// TODO(e.burkov): Use expiration cache.
cached map[string]*ipResult
}

// NewCachingResolver creates a new caching resolver that uses r for lookups.
func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) {
return &CachingResolver{
resolver: r,
mu: &sync.RWMutex{},
cached: map[string][]ipResult{},
cached: map[string]*ipResult{},
}
}

// type check
var _ Resolver = (*CachingResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *CachingResolver.
//
// TODO(e.burkov): It may appear that several concurrent lookup results rewrite
// each other in the cache.
func (r *CachingResolver) LookupNetIP(
ctx context.Context,
network bootstrap.Network,
Expand All @@ -313,17 +307,12 @@ func (r *CachingResolver) LookupNetIP(
return []netip.Addr{}, err
}

addrs = filterExpired(newRes, now)
if len(addrs) == 0 {
return []netip.Addr{}, nil
}

r.mu.Lock()
defer r.mu.Unlock()

r.cached[host] = newRes

return addrs, nil
return newRes.addrs, nil
}

// findCached returns the cached addresses for host if it's not expired yet, and
Expand All @@ -333,9 +322,9 @@ func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.
defer r.mu.RUnlock()

res, ok := r.cached[host]
if !ok {
if !ok || res.expire.Before(now) {
return nil
}

return filterExpired(res, now)
return res.addrs
}

0 comments on commit d0c2995

Please # to comment.