From 2faecc1be8d72ee5eb1a24a671e2cfa75f8093df Mon Sep 17 00:00:00 2001 From: tiancheng91 Date: Mon, 16 Jan 2023 15:15:24 +0800 Subject: [PATCH] feat: support node sort by tcp ping latency --- cmd/gost/peer.go | 21 ++++++----- cmd/gost/route.go | 1 + selector.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++ selector_test.go | 24 +++++++++++++ 4 files changed, 127 insertions(+), 8 deletions(-) diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go index 6f5b32ce..af062c71 100644 --- a/cmd/gost/peer.go +++ b/cmd/gost/peer.go @@ -13,14 +13,16 @@ import ( ) type peerConfig struct { - Strategy string `json:"strategy"` - MaxFails int `json:"max_fails"` - FailTimeout time.Duration - period time.Duration // the period for live reloading - Nodes []string `json:"nodes"` - group *gost.NodeGroup - baseNodes []gost.Node - stopped chan struct{} + Strategy string `json:"strategy"` + MaxFails int `json:"max_fails"` + FastestCount int `json:"fastest_count"` // topN fastest node count + FailTimeout time.Duration + period time.Duration // the period for live reloading + + Nodes []string `json:"nodes"` + group *gost.NodeGroup + baseNodes []gost.Node + stopped chan struct{} } func newPeerConfig() *peerConfig { @@ -51,6 +53,7 @@ func (cfg *peerConfig) Reload(r io.Reader) error { FailTimeout: cfg.FailTimeout, }, &gost.InvalidFilter{}, + gost.NewFastestFilter(0, cfg.FastestCount), ), gost.WithStrategy(gost.NewStrategy(cfg.Strategy)), ) @@ -125,6 +128,8 @@ func (cfg *peerConfig) parse(r io.Reader) error { cfg.Strategy = ss[1] case "max_fails": cfg.MaxFails, _ = strconv.Atoi(ss[1]) + case "fastest_count": + cfg.FastestCount, _ = strconv.Atoi(ss[1]) case "fail_timeout": cfg.FailTimeout, _ = time.ParseDuration(ss[1]) case "reload": diff --git a/cmd/gost/route.go b/cmd/gost/route.go index e6f0c633..360bc2d1 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -66,6 +66,7 @@ func (r *route) parseChain() (*gost.Chain, error) { FailTimeout: nodes[0].GetDuration("fail_timeout"), }, &gost.InvalidFilter{}, + gost.NewFastestFilter(0, nodes[0].GetInt("fastest_count")), ), gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))), ) diff --git a/selector.go b/selector.go index bff6d11f..a5f68b91 100644 --- a/selector.go +++ b/selector.go @@ -4,10 +4,13 @@ import ( "errors" "math/rand" "net" + "sort" "strconv" "sync" "sync/atomic" "time" + + "github.com/go-log/log" ) var ( @@ -205,6 +208,92 @@ func (f *FailFilter) String() string { return "fail" } +// FastestFilter filter the fastest node +type FastestFilter struct { + mu sync.Mutex + + pinger *net.Dialer + pingResult map[int]int + pingResultTTL map[int]int64 + + topCount int +} + +func NewFastestFilter(pingTimeOut int, topCount int) *FastestFilter { + if pingTimeOut == 0 { + pingTimeOut = 3000 // 3s + } + return &FastestFilter{ + mu: sync.Mutex{}, + pinger: &net.Dialer{Timeout: time.Millisecond * time.Duration(pingTimeOut)}, + pingResult: make(map[int]int, 0), + pingResultTTL: make(map[int]int64, 0), + topCount: topCount, + } +} + +func (f *FastestFilter) Filter(nodes []Node) []Node { + // disabled + if f.topCount == 0 { + return nodes + } + + // get latency with ttl cache + now := time.Now().Unix() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + var getNodeLatency = func(node Node) int { + if f.pingResultTTL[node.ID] < now { + f.mu.Lock() + f.pingResultTTL[node.ID] = now + 5 // tmp + defer f.mu.Unlock() + + // get latency + go func(node Node) { + latency := f.doTcpPing(node.Addr) + ttl := 300 - int64(60*r.Float64()) + + f.mu.Lock() + f.pingResult[node.ID] = latency + f.pingResultTTL[node.ID] = now + ttl + defer f.mu.Unlock() + }(node) + } + return f.pingResult[node.ID] + } + + // sort + sort.Slice(nodes, func(i, j int) bool { + return getNodeLatency(nodes[i]) < getNodeLatency(nodes[j]) + }) + + // split + if len(nodes) <= f.topCount { + return nodes + } + + return nodes[0:f.topCount] +} + +func (f *FastestFilter) String() string { + return "fastest" +} + +// doTcpPing +func (f *FastestFilter) doTcpPing(address string) int { + start := time.Now() + conn, err := f.pinger.Dial("tcp", address) + elapsed := time.Since(start) + + if err == nil { + _ = conn.Close() + } + + latency := int(elapsed.Milliseconds()) + log.Logf("pingDoTCP: %s, latency: %d", address, latency) + return latency +} + // InvalidFilter filters the invalid node. // A node is invalid if its port is invalid (negative or zero value). type InvalidFilter struct{} diff --git a/selector_test.go b/selector_test.go index 5da667cf..7cbbfd05 100644 --- a/selector_test.go +++ b/selector_test.go @@ -127,6 +127,30 @@ func TestFailFilter(t *testing.T) { } } +func TestFastestFilter(t *testing.T) { + nodes := []Node{ + Node{ID: 1, marker: &failMarker{}, Addr: "1.0.0.1:80"}, + Node{ID: 2, marker: &failMarker{}, Addr: "1.0.0.2:80"}, + Node{ID: 3, marker: &failMarker{}, Addr: "1.0.0.3:80"}, + } + filter := NewFastestFilter(0, 2) + + var print = func(nodes []Node) []string { + var rows []string + for _, node := range nodes { + rows = append(rows, node.Addr) + } + return rows + } + + result1 := filter.Filter(nodes) + t.Logf("result 1: %+v", print(result1)) + + time.Sleep(time.Second) + result2 := filter.Filter(nodes) + t.Logf("result 2: %+v", print(result2)) +} + func TestSelector(t *testing.T) { nodes := []Node{ Node{ID: 1, marker: &failMarker{}},