-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathproxy.go
337 lines (280 loc) · 7.85 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
package vink
import (
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// Proxy is the proxy layer. It pairs sessions accepted on the client side
// with sessions accepted on the server side and relays frames between them
// over numbered virtual connections.
type Proxy struct {
	proxyProtocol
	// sideServers holds the listener-backed server for each side, indexed by
	// SERVER_SIDE / CLIENT_SIDE.
	sideServers [SIDE_NUM]*EchoServer
	// proxySessionId is incremented atomically to number client-side sessions.
	proxySessionId uint32
	// sessionhubs registers live sessions by id, sharded by id % connSlots and
	// split per side.
	sessionhubs [connSlots][PEER_NUM]*sessionHub
	// gVConnId is incremented atomically to allocate virtual-connection ids.
	gVConnId uint32
	// peerMapList maps a virtual-connection id to its Peers, sharded by
	// id % connSlots.
	peerMapList [connSlots]map[uint32]*Peers
	// peerMapLock guards the peerMapList shard that has the same index.
	peerMapLock [connSlots]sync.RWMutex
}

// Peers holds the two sessions that form one virtual connection, indexed by
// side.
type Peers struct {
	cs [PEER_NUM]*Session
}

// ProxyConfig carries the per-listener settings used by ServiceSideServ and
// ClientSideServ.
type ProxyConfig struct {
	// MaxConn caps the number of virtual connections per client session;
	// 0 disables the cap.
	MaxConn int
	// MsgBuffSize is handed to the codec created for each accepted connection.
	MsgBuffSize int
	// MsgSendChanSize is handed to NewEchoServer.
	MsgSendChanSize int
	// ProxyIdleTimeout is the deadline applied before every receive; values
	// <= 0 disable it.
	ProxyIdleTimeout time.Duration
	// AuthKey is the key used for the authentication exchange with servers.
	AuthKey string
}
// GetPeers returns a copy of the session pair stored in ps.
func (ps *Peers) GetPeers() [2]*Session {
	pair := ps.cs
	return pair
}
// NewProxy builds a Proxy that uses maxPacketSize as its packet size limit and
// pre-allocates, for every shard, the peer map and one session hub per side.
func NewProxy(maxPacketSize int) *Proxy {
	p := &Proxy{}
	p.maxPacketSize = maxPacketSize

	// Allocate every shard up front.
	for slot := range p.peerMapList {
		p.peerMapList[slot] = make(map[uint32]*Peers)

		hubs := &p.sessionhubs[slot]
		hubs[CLIENT_SIDE] = NewSessionHub()
		hubs[SERVER_SIDE] = NewSessionHub()
	}
	return p
}
// ServiceSideServ serves backend servers connecting on ln. Each accepted
// connection goes through the sendAuth exchange with cfg.AuthKey; the id it
// yields becomes the session id. The call installs the server in
// p.sideServers[SERVER_SIDE] and then runs its Serve method.
func (p *Proxy) ServiceSideServ(ln net.Listener, cfg ProxyConfig) {
	protocol := ProtocolFunc(func(rw io.ReadWriter) (Codec, error) {
		conn := rw.(net.Conn)
		serverId, err := p.sendAuth(conn, []byte(cfg.AuthKey))
		if err != nil {
			GetLogger().Errorf("error accept server from %s: %s", conn.RemoteAddr(), err)
			return nil, err
		}
		GetLogger().Infof("accept server %d from %s", serverId, conn.RemoteAddr())
		return p.newProxyCodec(serverId, conn, cfg.MsgBuffSize), nil
	})

	// Server-side sessions are handled with maxConn == 0, i.e. no cap.
	handler := HandlerFunc(func(session *Session) {
		p.handleSession(session, SERVER_SIDE, 0, cfg.ProxyIdleTimeout)
	})

	p.sideServers[SERVER_SIDE] = NewEchoServer(ln, protocol, handler, cfg.MsgSendChanSize)
	p.sideServers[SERVER_SIDE].Serve()
}
// ClientSideServ serves clients connecting on ln. Each accepted connection is
// given the next value of the atomic proxySessionId counter as its session
// id. The call installs the server in p.sideServers[CLIENT_SIDE] and then
// runs its Serve method.
func (p *Proxy) ClientSideServ(ln net.Listener, cfg ProxyConfig) {
	protocol := ProtocolFunc(func(rw io.ReadWriter) (Codec, error) {
		sessionId := atomic.AddUint32(&p.proxySessionId, 1)
		codec := p.newProxyCodec(sessionId, rw.(net.Conn), cfg.MsgBuffSize)
		return codec, nil
	})

	// Client sessions are subject to the cfg.MaxConn cap.
	handler := HandlerFunc(func(session *Session) {
		p.handleSession(session, CLIENT_SIDE, cfg.MaxConn, cfg.ProxyIdleTimeout)
	})

	p.sideServers[CLIENT_SIDE] = NewEchoServer(ln, protocol, handler, cfg.MsgSendChanSize)
	p.sideServers[CLIENT_SIDE].Serve()
}
// handleSession runs the receive loop for one physical session.
//
// It registers the session in the hub for its side, then reads frames until
// a receive or deadline error occurs. A frame whose connection id is 0 is a
// proxy command and goes to processCmd. Any other frame is relayed to the
// opposite peer of that virtual connection. If the virtual connection is
// unknown or incomplete, a close command is sent back to the sender instead.
//
// side is the side this session belongs to, maxConn is the per-session
// virtual-connection cap (0 means no cap), and idleTimeout, when positive, is
// re-armed as a connection deadline before every receive.
//
// On exit, normal or by panic, the session record is deleted and any panic
// is recovered and logged, so one misbehaving session cannot take the proxy
// down.
func (p *Proxy) handleSession(session *Session, side, maxConn int, idleTimeout time.Duration) {
	codec := session.codec.(*proxyCodec)
	fid := codec.id
	record := p.newSessionRecord(fid, session)
	session.snapshot = record
	p.addconnSession(fid, side, session)

	defer func() {
		record.Delete()
		if err := recover(); err != nil {
			GetLogger().Error("Proxy Panic:", err)
		}
	}()

	// side is 0 or 1; clearing its bit out of 1 yields the other value.
	otherside := 1 &^ side
	GetLogger().Infof("side = %d , othersize = %d ", side, otherside)

	for {
		// Re-arm the idle deadline before every receive.
		if idleTimeout > 0 {
			if err := codec.conn.SetDeadline(time.Now().Add(idleTimeout)); err != nil {
				GetLogger().Error(err)
				return
			}
		}

		buf, err := session.Receive()
		if err != nil {
			return
		}
		msg := *(buf.(*[]byte))

		connId := p.decodeCmdFid(msg)
		if connId == 0 {
			p.processCmd(msg, session, side, maxConn)
			continue
		}

		// Look up the virtual connection. getPeers returns nil when the id is
		// unknown (never opened, or already closed by the other end); treat
		// that like an incomplete pair and tell the sender to close, rather
		// than dereferencing nil and killing the whole session.
		peers := p.getPeers(connId)
		if peers == nil {
			p.send(session, p.encodeCloseCmd(connId))
			continue
		}

		pair := peers.GetPeers()
		if pair[side] == nil || pair[otherside] == nil {
			p.send(session, p.encodeCloseCmd(connId))
			continue
		}
		if pair[side] != session {
			// The frame names a virtual connection owned by another session.
			panic("peer session info not match")
		}
		p.send(pair[otherside], msg)
	}
}
// processCmd executes one proxy command frame received from session.
//
//   - CMD_DIAL: look up the requested remote session on the opposite side and
//     try to establish a virtual connection; a refuse command is sent back if
//     the remote is not registered or acceptPeers declines.
//   - CMD_CLOSE: close the named virtual connection.
//   - CMD_PING: answer with a ping command.
//
// Any other command type panics; the panic is expected to be recovered by the
// caller's deferred handler.
//
// side is the side session belongs to and maxConn is the per-session
// virtual-connection cap forwarded to acceptPeers.
func (p *Proxy) processCmd(msg []byte, session *Session, side, maxConn int) {
	otherside := 1 &^ side
	cmd := p.decodeCmdType(msg)
	GetLogger().Debugf("%v | proxy processCmd = %d", msg, cmd)

	switch cmd {
	case CMD_DIAL:
		remoteId := p.decodeDialCmd(msg)

		var peers [2]*Session
		peers[side] = session
		peers[otherside] = p.getConnSession(remoteId, otherside)

		info := &Peers{peers}
		if peers[otherside] == nil || !p.acceptPeers(info, session, maxConn) {
			p.send(session, p.encodeRefuseCmd(remoteId))
		}
	case CMD_CLOSE:
		p.closePeers(p.decodeCloseCmd(msg))
	case CMD_PING:
		p.send(session, p.encodePingCmd())
	default:
		panic(fmt.Sprintf("unsupported proxy command : %d", cmd))
	}
}
// Stop stops the server-side and then the client-side server. A side that
// was never installed by ServiceSideServ / ClientSideServ is skipped, so Stop
// is safe to call on a proxy that serves only one side or none yet.
func (p *Proxy) Stop() {
	if srv := p.sideServers[SERVER_SIDE]; srv != nil {
		srv.Stop()
	}
	if srv := p.sideServers[CLIENT_SIDE]; srv != nil {
		srv.Stop()
	}
}
// SessionRecord is the per-session snapshot kept by the proxy. It is stored in
// Session.snapshot by handleSession and tracks which virtual connections the
// session currently takes part in.
type SessionRecord struct {
// The embedded mutex guards deleted and connsMap.
sync.Mutex
// fid is the id of the session's codec.
fid uint32
// proxy is the owning proxy, used to close virtual connections on Delete.
proxy *Proxy
// session is the physical session this record describes.
session *Session
// lastActive is not referenced in this file — NOTE(review): confirm whether it is still used.
lastActive int64
// hbChan is created by newSessionRecord; no sender or receiver is visible in this file.
hbChan chan struct{}
// connsMap is the set of virtual-connection ids this session belongs to.
connsMap map[uint32]struct{}
// deleteOnce makes Delete run its cleanup only once.
deleteOnce sync.Once
// deleted is set by Delete; acceptPeers and closePeers skip deleted records.
deleted bool
}
// newSessionRecord creates the record for session, identified by fid and
// owned by p, with an empty virtual-connection set.
func (p *Proxy) newSessionRecord(fid uint32, session *Session) *SessionRecord {
	record := new(SessionRecord)
	record.fid = fid
	record.session = session
	record.proxy = p
	record.hbChan = make(chan struct{})
	record.connsMap = make(map[uint32]struct{})
	return record
}
// Delete tears the record down exactly once: it closes the session, marks the
// record deleted under its lock, and then closes every virtual connection
// still listed in connsMap. Later calls are no-ops.
func (record *SessionRecord) Delete() {
	teardown := func() {
		record.session.Close()

		record.Lock()
		record.deleted = true
		record.Unlock()

		// Close the peer virtual connections associated with this session.
		for connId := range record.connsMap {
			record.proxy.closePeers(connId)
		}
	}
	record.deleteOnce.Do(teardown)
}
// addconnSession registers session under connId in the hub for the given side
// (shard connId % connSlots).
//
// Per the original author's note: Put also registers a close callback on the
// session, so the mapping is removed automatically once the session closes.
func (p *Proxy) addconnSession(connId uint32, side int, session *Session) {
	hub := p.sessionhubs[connId%connSlots][side]
	hub.Put(connId, session)
}
// getConnSession returns whatever the hub for the given side (shard
// connId % connSlots) has stored under connId.
func (p *Proxy) getConnSession(connId uint32, side int) *Session {
	hub := p.sessionhubs[connId%connSlots][side]
	return hub.Get(connId)
}
// addPeers records a new virtual connection: it stores peers under connId in
// the matching shard while holding that shard's write lock. It panics if the
// id is already present.
func (p *Proxy) addPeers(connId uint32, peers *Peers) {
	slot := connId % connSlots
	lock := &p.peerMapLock[slot]
	lock.Lock()
	defer lock.Unlock()

	table := p.peerMapList[slot]
	if _, dup := table[connId]; dup {
		panic("virtual connection already exists")
	}
	table[connId] = peers
}
// getPeers looks up the virtual connection connId and returns its Peers, or
// nil when no such connection is registered.
//
// The lookup only reads the shard map, so it takes the shard's read lock;
// concurrent lookups on the same shard do not serialize behind each other.
func (p *Proxy) getPeers(connId uint32) *Peers {
	slotIndex := connId % connSlots
	p.peerMapLock[slotIndex].RLock()
	defer p.peerMapLock[slotIndex].RUnlock()
	return p.peerMapList[slotIndex][connId]
}
// delPeers removes the virtual connection connId from its shard under the
// shard's write lock. It returns the removed Peers and true, or nil and false
// when the connection was not present (for example because it has already
// been cleaned up).
func (p *Proxy) delPeers(connId uint32) (*Peers, bool) {
	slot := connId % connSlots
	lock := &p.peerMapLock[slot]
	lock.Lock()
	defer lock.Unlock()

	table := p.peerMapList[slot]
	peers, found := table[connId]
	if !found {
		return peers, false
	}
	delete(table, connId)
	return peers, true
}
// acceptPeers handles a request to open a virtual connection between the two
// sessions in peers. session is the side that asked for the connection.
//
// It allocates a fresh non-zero connection id, validates both session
// records, registers the id in both records and in the proxy's peer table,
// and finally notifies both ends: the requester gets an accept command, the
// other end a connect command, each carrying the id of the opposite session.
//
// It returns false, registering nothing, when either record is already
// deleted or when the requesting session already holds maxConn virtual
// connections (maxConn == 0 disables the cap).
func (p *Proxy) acceptPeers(peers *Peers, session *Session, maxConn int) bool {
	// Connection id 0 is what command frames carry, so skip it if the counter
	// wraps around.
	var connId uint32
	for connId == 0 {
		connId = atomic.AddUint32(&p.gVConnId, 1)
	}

	members := peers.GetPeers()
	var records [len(members)]*SessionRecord

	// Phase 1: lock and validate every record. The deferred unlocks are
	// intentional: all records stay locked until the connection has been
	// published and announced, so a concurrent Delete cannot slip in between.
	for i, member := range members {
		record := member.snapshot.(*SessionRecord)
		record.Lock()
		defer record.Unlock()

		if record.deleted {
			GetLogger().Error("proxy was deleted")
			return false
		}
		if member == session && maxConn != 0 && len(record.connsMap) >= maxConn {
			GetLogger().Error("session connections size over maxconn")
			return false
		}
		if _, exists := record.connsMap[connId]; exists {
			panic("peers connection already exists")
		}
		records[i] = record
	}

	// Phase 2: only now that nothing can be rejected, register the id in every
	// record. Registering while validating would leave a stale id behind in
	// the first record whenever the second one is refused.
	for _, record := range records {
		record.connsMap[connId] = struct{}{}
	}

	// Publish the virtual connection.
	p.addPeers(connId, peers)

	// Tell each end about the connection and the id of the session opposite it.
	for i, member := range members {
		remoteId := records[(i+1)%len(records)].fid
		if member == session {
			p.send(member, p.encodeAcceptCmd(connId, remoteId))
		} else {
			p.send(member, p.encodeConnectCmd(connId, remoteId))
		}
	}
	return true
}
// closePeers closes the virtual connection connId. It removes the connection
// from the peer table and, for each end whose record is not already deleted,
// drops the id from that record and sends the end a close command. It does
// nothing when the connection is not (or no longer) registered.
//
// It only edits bookkeeping and sends close commands; it never closes the
// underlying sessions themselves.
func (p *Proxy) closePeers(connId uint32) {
	peers, ok := p.delPeers(connId)
	if !ok {
		return
	}

	for _, peer := range peers.GetPeers() {
		// Scope each record's lock to its own iteration. A defer placed
		// directly in the loop would keep the first record locked while the
		// second one is processed and sent to.
		func(peer *Session) {
			record := peer.snapshot.(*SessionRecord)
			record.Lock()
			defer record.Unlock()

			if record.deleted {
				return
			}
			delete(record.connsMap, connId)
			p.send(peer, p.encodeCloseCmd(connId))
		}(peer)
	}
}