-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathssh_tunnel.go
199 lines (170 loc) · 5.15 KB
/
ssh_tunnel.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
package mgoc
import (
	"encoding/json"
	"fmt"
	"net"
	"strings"
	"time"

	"github.com/civet148/gotools/wss"
	_ "github.com/civet148/gotools/wss/tcpsock" //required (register socket instance)
	"github.com/civet148/log"
	"github.com/elliotchance/sshtunnel"
	"golang.org/x/crypto/ssh"
)
// defaultSshTunnelPort is the first local TCP port probed by
// tryListenRandomPort when it looks for a free port for the tunnel's
// listening end.
const defaultSshTunnelPort = 6033
// dsnDriver holds the pieces of a MongoDB connection URL that are needed to
// rebuild the DSN so it points at the local end of the SSH tunnel
// (see getDsnDriver and buildTunnelDSN).
type dsnDriver struct {
	host     string //ip:port of the MongoDB server; used as the tunnel destination
	ip       string //ip part of host (as returned by getHostPort)
	port     string //port part of host (as returned by getHostPort)
	user     string // user name taken from the URL
	password string // password taken from the URL
	db       string // database name derived from the URL path
	charset  string // NOTE(review): not populated by getDsnDriver in this file
	slave    bool   // NOTE(review): not populated by getDsnDriver in this file
	max      int    // NOTE(review): not populated by getDsnDriver in this file
	idle     int    // NOTE(review): not populated by getDsnDriver in this file
	strDSN   string // NOTE(review): not populated by getDsnDriver in this file
	queries  map[string]string // URL query options, re-emitted by buildTunnelDSN
}
// SSH describes the SSH server through which MongoDB connections are tunnelled
// (local port forwarding).
//
// The exported fields are the caller-supplied configuration. The unexported
// fields are runtime state filled in by startSSHTunnel.
type SSH struct {
	User       string //login account on the SSH tunnel server
	Password   string //login password; used only when PrivateKey is empty
	PrivateKey string //path to a private key file, eg. "/home/test/.ssh/private-key.pem"; takes precedence over Password
	Host       string //tunnel server host [ip or domain], optionally with ":port"; port 22 is used if none is given
	listenPort int    //local port of the tunnel's forwarding endpoint
	listenIP   string //local ip of the tunnel's forwarding endpoint (set to 127.0.0.1 by tryListenRandomPort)
	tunnel     *sshtunnel.SSHTunnel //tunnel instance created by startSSHTunnel
	authMethod ssh.AuthMethod       //auth method derived from PrivateKey or Password by startSSHTunnel
}
// GoString implements fmt.GoStringer so that the %#v verb prints exactly the
// same text as String instead of a Go-syntax dump of the struct.
func (s *SSH) GoString() string { return s.String() }
// String renders the SSH settings as JSON (exported fields only) for use in
// log and error messages. A non-empty Password is replaced by "******" so the
// credential never ends up in logs. A nil *SSH renders as "null", matching
// what json.Marshal produces for a nil pointer.
func (s *SSH) String() string {
	if s == nil {
		return "null"
	}
	// Mask on a shallow copy so the caller's configuration is left untouched.
	redacted := *s
	if redacted.Password != "" {
		redacted.Password = "******"
	}
	// Every exported field is a string, so Marshal cannot fail here.
	data, _ := json.Marshal(&redacted)
	return string(data)
}
// setDefaultPort appends the standard SSH port ":22" to s.Host when the host
// string contains no ':' at all. Any host that already contains a colon is
// left as-is.
func (s *SSH) setDefaultPort() {
	if strings.Contains(s.Host, ":") {
		return
	}
	s.Host += ":22"
}
// openSSHTunnel parses the MongoDB connection URL, starts an SSH tunnel to the
// MongoDB host it names, and returns an equivalent DSN that targets the local
// end of that tunnel.
//
// On failure it returns an empty DSN and the error produced by log.Errorf.
func (s *SSH) openSSHTunnel(strMgoUrl string) (strDSN string, err error) {
	ui := ParseUrl(strMgoUrl)
	if ui == nil {
		// NOTE(review): assumes ParseUrl reports an unparsable URL by returning nil — confirm.
		// The URL itself is not logged because it may embed credentials.
		return "", log.Errorf("parse MongoDB url failed")
	}
	dsn := getDsnDriver(ui)
	if !s.startSSHTunnel(&dsn) {
		// Report only user@host: formatting the whole config could expose the password.
		return "", log.Errorf("start SSH tunnel service failed, please make sure your tunnel server config [%s@%s] is correct", s.User, s.Host)
	}
	return s.buildTunnelDSN(&dsn), nil
}
// startSSHTunnel sets up an SSH local port forward towards the MongoDB server
// at dsn.host and reports whether the tunnel is ready for use.
//
// It performs these steps:
//  1. Normalises s.Host with a default port.
//  2. Picks a free local port (s.listenIP / s.listenPort).
//  3. Chooses key-file or password authentication.
//  4. Verifies the SSH server accepts those settings.
//  5. Starts the forwarding loop in a background goroutine.
//  6. Waits until the local port accepts TCP connections.
//
// Every failure is logged and reported as false.
func (s *SSH) startSSHTunnel(dsn *dsnDriver) (ok bool) {
	const (
		// Upper bound for the TCP connect to the SSH server; a zero
		// ssh.ClientConfig.Timeout would mean no timeout at all.
		sshDialTimeout = 10 * time.Second
		// s.start binds the local port concurrently, so poll for it for up to
		// readyAttempts*readyInterval instead of probing a single time.
		readyAttempts = 30
		readyInterval = 100 * time.Millisecond
	)

	s.setDefaultPort()
	if s.listenPort, ok = s.tryListenRandomPort(); !ok {
		log.Errorf("try listen random port for SSH tunnel failed")
		return false
	}

	if s.PrivateKey != "" {
		s.authMethod = sshtunnel.PrivateKeyFile(s.PrivateKey)
		if s.authMethod == nil {
			// NOTE(review): assumes sshtunnel.PrivateKeyFile returns nil when the key
			// file cannot be read or parsed — confirm for the version in use.
			log.Errorf("load SSH private key [%s] failed", s.PrivateKey)
			return false
		}
	} else {
		s.authMethod = ssh.Password(s.Password)
	}

	s.tunnel = sshtunnel.NewSSHTunnel(
		// User and host of the tunnel server ("user@host:port").
		s.User+"@"+s.Host,
		s.authMethod,
		// Destination host and port of the MongoDB server, as seen from the SSH server.
		dsn.host,
		// Local port the destination is forwarded to.
		fmt.Sprintf("%d", s.listenPort),
	)
	if cfg := s.tunnel.Config; cfg != nil && cfg.Timeout == 0 {
		cfg.Timeout = sshDialTimeout
	}

	// Make sure the tunnel server is reachable before starting the service.
	if err := s.tryConnectSSH(); err != nil {
		log.Errorf("try connect to SSH tunnel server [%s] failed", s.Host)
		return false
	}

	// NOTE(review): nothing stops this goroutine if the readiness check below fails.
	go s.start()

	strLocalAddr := fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
	for attempt := 0; attempt < readyAttempts; attempt++ {
		if attempt > 0 {
			time.Sleep(readyInterval)
		}
		if conn, err := net.DialTimeout("tcp", strLocalAddr, readyInterval); err == nil {
			_ = conn.Close()
			return true
		}
	}
	log.Errorf("connect to transfer address [%s] error", strLocalAddr)
	return false
}
// start runs the tunnel service via s.tunnel.Start. startSSHTunnel launches it
// with `go`. Presumably s.tunnel.Start blocks while the tunnel is serving —
// confirm against the sshtunnel package.
//
// A failure is logged and returned instead of panicking. A panic raised in a
// background goroutine cannot be recovered by the caller and would terminate
// the whole host process. If the listener never comes up, startSSHTunnel's
// probe of the local port fails and reports the problem.
func (s *SSH) start() (err error) {
	// To enable sshtunnel's own logging:
	//s.tunnel.Log = logger.New(os.Stdout, "", logger.Ldate|logger.Lmicroseconds)
	if err = s.tunnel.Start(); err != nil {
		log.Errorf("start tunnel service error [%s]", err)
	}
	return err
}
// tryConnectSSH checks that the SSH server at s.Host can be dialled with the
// tunnel's client configuration. It opens a client connection and closes it
// straight away. It returns the dial error, or nil on success.
func (s *SSH) tryConnectSSH() (err error) {
	client, err := ssh.Dial("tcp", s.Host, s.tunnel.Config)
	if err == nil {
		// The connection is only a probe and is not reused.
		_ = client.Close()
		return nil
	}
	log.Errorf("try connect ssh [%s] config [%+v] error [%s]", s.Host, s.tunnel.Config, err.Error())
	return err
}
// tryConnect checks that strHost ("ip:port") accepts TCP connections. It opens
// a client connection through the wss package (the tcp driver is registered by
// the blank tcpsock import) and closes it immediately. It returns the connect
// error, or nil on success.
func (s *SSH) tryConnect(strHost string) (err error) {
	c := wss.NewClient()
	if err = c.Connect(fmt.Sprintf("tcp://%s", strHost)); err != nil {
		// Include the underlying error; without it the log gives no hint why the probe failed.
		log.Errorf("connect to [%s] error [%s]", strHost, err)
		return err
	}
	// Probe only: a close failure does not change the outcome.
	_ = c.Close()
	return nil
}
// tryListenRandomPort finds a free local TCP port for the tunnel's listening
// end. Despite the name, the search is sequential: it sets s.listenIP to
// 127.0.0.1 and tries defaultSshTunnelPort, defaultSshTunnelPort+1, … (at most
// maxAttempts ports, never beyond 65535). It returns the first port it could
// bind, with ok == true, or (0, false) if none was free.
//
// The probe listener is closed immediately so the tunnel can bind the port
// later. Another process could take the port in between.
func (s *SSH) tryListenRandomPort() (port int, ok bool) {
	const (
		maxAttempts = 1000
		maxTCPPort  = 65535
	)
	s.listenIP = "127.0.0.1"
	for i := 0; i < maxAttempts; i++ {
		// Derive each candidate from the base port; accumulating the loop
		// index would skip ever larger gaps of ports.
		candidate := defaultSshTunnelPort + i
		if candidate > maxTCPPort {
			break
		}
		listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.listenIP, candidate))
		if err != nil {
			continue
		}
		_ = listener.Close()
		return candidate, true
	}
	return 0, false
}
// buildTunnelDSN rewrites the MongoDB DSN described by d so that it targets the
// local end of the SSH tunnel (s.listenIP:s.listenPort) instead of the original
// host. It keeps the credentials, database and query options from d.
//
// The "user:password@" section is emitted only when d.user is set. MongoDB URI
// parsers, such as the official Go driver, reject an empty user name in the
// userinfo (":@"). The "?" separator is likewise added only when there are
// query options. Options follow Go's unspecified map iteration order, which
// carries no meaning in a MongoDB URI.
//
// NOTE(review): user, password and option values are copied verbatim, so they
// must already be percent-encoded where required — confirm what ParseUrl returns.
func (s *SSH) buildTunnelDSN(d *dsnDriver) (strDSN string) {
	var userInfo string
	if d.user != "" {
		userInfo = d.user + ":" + d.password + "@"
	}
	strDSN = fmt.Sprintf("mongodb://%s%s:%d/%s", userInfo, s.listenIP, s.listenPort, d.db)
	if len(d.queries) == 0 {
		return strDSN
	}
	kvs := make([]string, 0, len(d.queries))
	for k, v := range d.queries {
		kvs = append(kvs, k+"="+v)
	}
	return strDSN + "?" + strings.Join(kvs, "&")
}
// getDsnDriver copies the connection details parsed from a MongoDB URL into a
// dsnDriver. It carries over:
//   - the credentials;
//   - the host, also split into ip and port by getHostPort;
//   - the database name derived from the URL path;
//   - the query options.
func getDsnDriver(ui *UrlInfo) (d dsnDriver) {
	ip, port := getHostPort(ui.Host)
	return dsnDriver{
		host:     ui.Host,
		ip:       ip,
		port:     port,
		user:     ui.User,
		password: ui.Password,
		db:       getDatabaseName(ui.Path),
		queries:  ui.Queries,
	}
}