Skip to content

Commit

Permalink
reverseproxy: Always return new upstreams (fix #5736)
Browse files Browse the repository at this point in the history
  • Loading branch information
mholt committed Aug 16, 2023
1 parent d6f86cc commit be299bb
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions modules/caddyhttp/reverseproxy/upstreams.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
cached := srvs[suAddr]
srvsMu.RUnlock()
if cached.isFresh() {
return cached.upstreams, nil
return allNew(cached.upstreams), nil
}

// otherwise, obtain a write-lock to update the cached value
Expand All @@ -126,7 +126,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// have refreshed it in the meantime before we re-obtained our lock
cached = srvs[suAddr]
if cached.isFresh() {
return cached.upstreams, nil
return allNew(cached.upstreams), nil
}

su.logger.Debug("refreshing SRV upstreams",
Expand All @@ -145,15 +145,15 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
su.logger.Warn("SRV records filtered", zap.Error(err))
}

upstreams := make([]*Upstream, len(records))
upstreams := make([]Upstream, len(records))
for i, rec := range records {
su.logger.Debug("discovered SRV record",
zap.String("target", rec.Target),
zap.Uint16("port", rec.Port),
zap.Uint16("priority", rec.Priority),
zap.Uint16("weight", rec.Weight))
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
upstreams[i] = &Upstream{Dial: addr}
upstreams[i] = Upstream{Dial: addr}
}

// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
Expand All @@ -170,7 +170,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams: upstreams,
}

return upstreams, nil
return allNew(upstreams), nil
}

func (su SRVUpstreams) String() string {
Expand Down Expand Up @@ -206,7 +206,7 @@ func (SRVUpstreams) formattedAddr(service, proto, name string) string {
type srvLookup struct {
srvUpstreams SRVUpstreams
freshness time.Time
upstreams []*Upstream
upstreams []Upstream
}

func (sl srvLookup) isFresh() bool {
Expand Down Expand Up @@ -325,7 +325,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
cached := aAaaa[auStr]
aAaaaMu.RUnlock()
if cached.isFresh() {
return cached.upstreams, nil
return allNew(cached.upstreams), nil
}

// otherwise, obtain a write-lock to update the cached value
Expand All @@ -337,7 +337,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// have refreshed it in the meantime before we re-obtained our lock
cached = aAaaa[auStr]
if cached.isFresh() {
return cached.upstreams, nil
return allNew(cached.upstreams), nil
}

name := repl.ReplaceAll(au.Name, "")
Expand All @@ -348,15 +348,15 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
return nil, err
}

upstreams := make([]*Upstream, len(ips))
upstreams := make([]Upstream, len(ips))
for i, ip := range ips {
upstreams[i] = &Upstream{
upstreams[i] = Upstream{
Dial: net.JoinHostPort(ip.String(), port),
}
}

// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
if cached.freshness.IsZero() && len(srvs) >= 100 {
if cached.freshness.IsZero() && len(aAaaa) >= 100 {
for randomKey := range aAaaa {
delete(aAaaa, randomKey)
break
Expand All @@ -369,15 +369,15 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams: upstreams,
}

return upstreams, nil
return allNew(upstreams), nil
}

func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) }

type aLookup struct {
aUpstreams AUpstreams
freshness time.Time
upstreams []*Upstream
upstreams []Upstream
}

func (al aLookup) isFresh() bool {
Expand Down Expand Up @@ -483,6 +483,14 @@ func (u *UpstreamResolver) ParseAddresses() error {
return nil
}

func allNew(upstreams []Upstream) []*Upstream {
results := make([]*Upstream, len(upstreams))
for i := range upstreams {
results[i] = &Upstream{Dial: upstreams[i].Dial}
}
return results
}

var (
srvs = make(map[string]srvLookup)
srvsMu sync.RWMutex
Expand Down

0 comments on commit be299bb

Please # to comment.