diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go index 6f5b32ce..af062c71 100644 --- a/cmd/gost/peer.go +++ b/cmd/gost/peer.go @@ -13,14 +13,16 @@ import ( ) type peerConfig struct { - Strategy string `json:"strategy"` - MaxFails int `json:"max_fails"` - FailTimeout time.Duration - period time.Duration // the period for live reloading - Nodes []string `json:"nodes"` - group *gost.NodeGroup - baseNodes []gost.Node - stopped chan struct{} + Strategy string `json:"strategy"` + MaxFails int `json:"max_fails"` + FastestCount int `json:"fastest_count"` // topN fastest node count + FailTimeout time.Duration + period time.Duration // the period for live reloading + + Nodes []string `json:"nodes"` + group *gost.NodeGroup + baseNodes []gost.Node + stopped chan struct{} } func newPeerConfig() *peerConfig { @@ -51,6 +53,7 @@ func (cfg *peerConfig) Reload(r io.Reader) error { FailTimeout: cfg.FailTimeout, }, &gost.InvalidFilter{}, + gost.NewFastestFilter(0, cfg.FastestCount), ), gost.WithStrategy(gost.NewStrategy(cfg.Strategy)), ) @@ -125,6 +128,8 @@ func (cfg *peerConfig) parse(r io.Reader) error { cfg.Strategy = ss[1] case "max_fails": cfg.MaxFails, _ = strconv.Atoi(ss[1]) + case "fastest_count": + cfg.FastestCount, _ = strconv.Atoi(ss[1]) case "fail_timeout": cfg.FailTimeout, _ = time.ParseDuration(ss[1]) case "reload": diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 7ebbdf9e..360bc2d1 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -66,6 +66,7 @@ func (r *route) parseChain() (*gost.Chain, error) { FailTimeout: nodes[0].GetDuration("fail_timeout"), }, &gost.InvalidFilter{}, + gost.NewFastestFilter(0, nodes[0].GetInt("fastest_count")), ), gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))), ) @@ -241,6 +242,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { tr = gost.FakeTCPTransporter() case "udp": tr = gost.UDPTransporter() + case "vsock": + tr = gost.VSOCKTransporter() default: tr = gost.TCPTransporter() } @@ -489,6 +492,8 @@ func (r *route) GenRouters() ([]router, error) { chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() } ln, err = gost.TCPListener(node.Addr) + case "vsock": + ln, err = gost.VSOCKListener(node.Addr) case "udp": ln, err = gost.UDPListener(node.Addr, &gost.UDPListenConfig{ TTL: ttl, diff --git a/go.mod b/go.mod index 90a0fada..dba4a6d2 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/klauspost/compress v1.17.6 github.com/miekg/dns v1.1.58 github.com/quic-go/quic-go v0.41.0 + github.com/mdlayher/vsock v1.2.1 github.com/ryanuber/go-glob v1.0.0 github.com/shadowsocks/go-shadowsocks2 v0.1.5 github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601 @@ -41,6 +42,8 @@ require ( github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/klauspost/reedsolomon v1.12.0 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/mdlayher/socket v0.4.1 // indirect + github.com/onsi/ginkgo/v2 v2.8.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/templexxx/cpu v0.1.0 // indirect @@ -54,4 +57,5 @@ require ( golang.org/x/sys v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.17.0 // indirect + golang.org/x/sync v0.1.0 // indirect ) diff --git a/go.sum b/go.sum index 92980335..d9b7f04d 100644 --- a/go.sum +++ b/go.sum @@ -134,8 +134,10 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= + golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= + golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -144,12 +146,14 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= + golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= + golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -168,7 +172,9 @@ golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= + golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= + golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -176,8 +182,10 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3 golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= + golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= + golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/http.go b/http.go index 02a7ad32..8ef96e80 100644 --- a/http.go +++ b/http.go @@ -379,7 +379,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. } else { resp.Header = http.Header{} resp.Header.Set("Server", "nginx/1.14.1") - resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) + resp.Header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) if resp.StatusCode == http.StatusOK { resp.Header.Set("Connection", "keep-alive") } diff --git a/http2.go b/http2.go index 79cb3819..de152eae 100644 --- a/http2.go +++ b/http2.go @@ -538,7 +538,7 @@ func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp } else { resp.Header = http.Header{} resp.Header.Set("Server", "nginx/1.14.1") - resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) + resp.Header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) if resp.ContentLength > 0 { resp.Header.Set("Content-Type", "text/html") } diff --git a/node.go b/node.go index f64afc43..f12c7450 100644 --- a/node.go +++ b/node.go @@ -90,6 +90,7 @@ func ParseNode(s string) (node Node, err error) { case "ftcp": // fake TCP case "dns": case "redu", "redirectu": // UDP tproxy + case "vsock": default: node.Transport = "tcp" } diff --git a/selector.go b/selector.go index bff6d11f..ccac77ea 100644 --- a/selector.go +++ b/selector.go @@ -4,10 +4,13 @@ import ( "errors" "math/rand" "net" + "sort" "strconv" "sync" "sync/atomic" "time" + + "github.com/go-log/log" ) var ( @@ -205,6 +208,94 @@ func (f *FailFilter) String() string { return "fail" } +// FastestFilter filter the fastest node +type FastestFilter struct { + mu sync.Mutex + + pinger *net.Dialer + pingResult map[int]int + pingResultTTL map[int]int64 + + topCount int +} + +func NewFastestFilter(pingTimeOut int, topCount int) *FastestFilter { + if pingTimeOut == 0 { + pingTimeOut = 3000 // 3s + } + return &FastestFilter{ + mu: sync.Mutex{}, + pinger: &net.Dialer{Timeout: time.Millisecond * time.Duration(pingTimeOut)}, + pingResult: make(map[int]int, 0), + pingResultTTL: make(map[int]int64, 0), + topCount: topCount, + } +} + +func (f *FastestFilter) Filter(nodes []Node) []Node { + // disabled + if f.topCount == 0 { + return nodes + } + + // get latency with ttl cache + now := time.Now().Unix() + + var getNodeLatency = func(node Node) int { + f.mu.Lock() + defer f.mu.Unlock() + + if f.pingResultTTL[node.ID] < now { + f.pingResultTTL[node.ID] = now + 5 // tmp + + // get latency + go func(node Node) { + latency := f.doTcpPing(node.Addr) + r := rand.New(rand.NewSource(time.Now().UnixNano())) + ttl := 300 - int64(120*r.Float64()) + + f.mu.Lock() + defer f.mu.Unlock() + + f.pingResult[node.ID] = latency + f.pingResultTTL[node.ID] = now + ttl + }(node) + } + return f.pingResult[node.ID] + } + + // sort + sort.Slice(nodes, func(i, j int) bool { + return getNodeLatency(nodes[i]) < getNodeLatency(nodes[j]) + }) + + // split + if len(nodes) <= f.topCount { + return nodes + } + + return nodes[0:f.topCount] +} + +func (f *FastestFilter) String() string { + return "fastest" +} + +// doTcpPing +func (f *FastestFilter) doTcpPing(address string) int { + start := time.Now() + conn, err := f.pinger.Dial("tcp", address) + elapsed := time.Since(start) + + if err == nil { + _ = conn.Close() + } + + latency := int(elapsed.Milliseconds()) + log.Logf("pingDoTCP: %s, latency: %d", address, latency) + return latency +} + // InvalidFilter filters the invalid node. // A node is invalid if its port is invalid (negative or zero value). type InvalidFilter struct{} diff --git a/selector_test.go b/selector_test.go index 5da667cf..7cbbfd05 100644 --- a/selector_test.go +++ b/selector_test.go @@ -127,6 +127,30 @@ func TestFailFilter(t *testing.T) { } } +func TestFastestFilter(t *testing.T) { + nodes := []Node{ + Node{ID: 1, marker: &failMarker{}, Addr: "1.0.0.1:80"}, + Node{ID: 2, marker: &failMarker{}, Addr: "1.0.0.2:80"}, + Node{ID: 3, marker: &failMarker{}, Addr: "1.0.0.3:80"}, + } + filter := NewFastestFilter(0, 2) + + var print = func(nodes []Node) []string { + var rows []string + for _, node := range nodes { + rows = append(rows, node.Addr) + } + return rows + } + + result1 := filter.Filter(nodes) + t.Logf("result 1: %+v", print(result1)) + + time.Sleep(time.Second) + result2 := filter.Filter(nodes) + t.Logf("result 2: %+v", print(result2)) +} + func TestSelector(t *testing.T) { nodes := []Node{ Node{ID: 1, marker: &failMarker{}}, diff --git a/vsock.go b/vsock.go new file mode 100644 index 00000000..51aa6dea --- /dev/null +++ b/vsock.go @@ -0,0 +1,76 @@ +package gost + +import ( + "net" + "strconv" + + "github.com/mdlayher/vsock" +) + +// vsockTransporter is a raw VSOCK transporter. +type vsockTransporter struct{} + +// VSOCKTransporter creates a raw VSOCK client. +func VSOCKTransporter() Transporter { + return &vsockTransporter{} +} + +func (tr *vsockTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + if opts.Chain == nil { + vAddr, err := parseAddr(addr) + if err != nil { + return nil, err + } + return vsock.Dial(vAddr.ContextID, vAddr.Port, nil) + } + return opts.Chain.Dial(addr) +} + +func parseUint32(s string) (uint32, error ) { + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return 0, err + } + return uint32(n), nil +} + +func parseAddr(addr string) (*vsock.Addr, error) { + hostStr, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + host := uint32(0) + if hostStr != "" { + host, err = parseUint32(hostStr) + if err != nil { + return nil, err + } + } + + port, err := parseUint32(portStr) + if err != nil { + return nil, err + } + return &vsock.Addr{ContextID: host, Port: port}, nil +} + +func (tr *vsockTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return conn, nil +} + +func (tr *vsockTransporter) Multiplex() bool { + return false +} + +// VSOCKListener creates a Listener for VSOCK proxy server. +func VSOCKListener(addr string) (Listener, error) { + vAddr, err := parseAddr(addr) + if err != nil { + return nil, err + } + return vsock.Listen(vAddr.Port, nil) +}