-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrequest_response_server.go
225 lines (171 loc) · 7.7 KB
/
request_response_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
package pbreqres
import (
"fmt"
"log"
"net/http"
"io"
"io/ioutil"
"sync"
"strings"
"strconv"
"mime"
"path"
)
// Header prefixes used to tunnel HTTP headers between the two sides of an
// exchange. Both sides currently use the same prefix.
const (
	// RequestPrefix is prepended to each of the requester's headers when they
	// are relayed to the responder as response headers.
	RequestPrefix = "Pb-H-"
	// ResponsePrefix marks headers on the responder's request that are
	// forwarded to the requester, with the prefix stripped.
	ResponsePrefix = "Pb-H-"
)
// PatchedRequest is a requester's pending request. The requester's handler
// sends it over a channel, and whichever responder receives it relays it.
type PatchedRequest struct {
	// uri is the requester's request URI; the responder receives it in the
	// Pb-Uri response header.
	uri string
	// httpRequest is the requester's original request. The responder reads its
	// Method, Header, Body and Context from it.
	httpRequest *http.Request
	// responseChan is where the responder delivers its PatchedReponse back to
	// the waiting requester.
	responseChan chan PatchedReponse
}
// PatchedReponse is a responder's reply to a PatchedRequest.
//
// NOTE(review): the type name is misspelled ("Reponse"). It is exported, so
// renaming it would break callers; an alias could be added instead.
type PatchedReponse struct {
	//body io.ReadCloser
	// responderRequest is the responder's own HTTP request. Its
	// ResponsePrefix-prefixed headers, its Pb-Status header and its body become
	// the requester's response.
	responderRequest *http.Request
	// doneSignal is closed by the requester once it has finished copying the
	// responder's body. The responder's handler waits on it so that its
	// request body stays readable until then.
	doneSignal chan struct{}
}
// RequestResponseServer relays HTTP requests from requesters to responders
// over named channels. Construct it with NewRequestResponseServer: the zero
// value has a nil map and a nil mutex and is not usable.
type RequestResponseServer struct {
	// channels maps a channel ID to the Go channel requesters are handed off
	// on. NOTE(review): nothing in this file ever deletes entries, so the map
	// grows with every distinct channel ID seen.
	channels map[string]chan PatchedRequest
	// mutex guards channels.
	mutex *sync.Mutex
}
// NewRequestResponseServer returns a ready-to-use server with an empty channel
// registry and its own mutex.
func NewRequestResponseServer() *RequestResponseServer {
	return &RequestResponseServer{
		channels: make(map[string]chan PatchedRequest),
		mutex:    &sync.Mutex{},
	}
}
// Handle serves both sides of a request/response exchange.
//
// Roles are chosen by path:
//   - /res/<channel-id> is a responder.
//   - /req/<channel-id> is a requester.
//   - Any other path is an "mpmc" channel keyed by the full URL path. POST
//     requests act as responders there, and every other method acts as a
//     requester.
//
// A requester waits until a responder on the same channel picks it up, or
// until the requester's context is done.
//
// The responder's HTTP response carries the requester's request:
//   - its URI, in Pb-Uri;
//   - its method, in Pb-Method;
//   - its headers, each prefixed with RequestPrefix;
//   - its body.
//
// The responder's own request then becomes the requester's response:
//   - headers prefixed with ResponsePrefix are copied with the prefix stripped;
//   - a Pb-Status header sets the status code;
//   - the responder's body is streamed to the requester.
//
// Responders may pass ?switch=true with a switch channel ID as the body. This
// is a "double clutch" request: the requester is parked on "/<id>" and gets
// answered by a second responder request to that path.
func (s *RequestResponseServer) Handle(w http.ResponseWriter, r *http.Request) {
	channelID, isResponder, errMsg := s.parseChannel(r)
	if errMsg != "" {
		w.WriteHeader(400)
		w.Write([]byte(errMsg))
		return
	}
	channel := s.channelFor(channelID)
	if isResponder {
		s.serveResponder(w, r, channel)
	} else {
		s.serveRequester(w, r, channel)
	}
}

// parseChannel extracts the channel ID and role (responder or requester) from
// r's path. errMsg is non-empty when the path is malformed.
func (s *RequestResponseServer) parseChannel(r *http.Request) (channelID string, isResponder bool, errMsg string) {
	pathParts := strings.Split(r.URL.Path, "/")
	if len(pathParts) < 2 {
		return "", false, "Invalid channel. Must start with '/'"
	}
	// TODO: For API version 1.0, change to default root path to being an
	// alias of /req/ instead of /mpmc/
	switch pathParts[1] {
	case "res":
		if len(pathParts) < 3 {
			return "", false, "Invalid responder channel. Must be of the form /res/<channel-id>"
		}
		return pathParts[2], true, ""
	case "req":
		if len(pathParts) < 3 {
			return "", false, "Invalid requester channel. Must be of the form /req/<channel-id>"
		}
		return pathParts[2], false, ""
	default:
		// default to /mpmc/
		return r.URL.Path, r.Method == "POST", ""
	}
}

// channelFor returns the channel registered under id, creating an unbuffered
// one if none exists yet.
func (s *RequestResponseServer) channelFor(id string) chan PatchedRequest {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	ch, ok := s.channels[id]
	if !ok {
		ch = make(chan PatchedRequest)
		s.channels[id] = ch
	}
	return ch
}

// serveResponder waits for a requester on channel and relays it to the
// responder: it writes the requester's request out as w's headers and body, then
// hands r back to the requester as the response, or parks the requester on a
// switch channel.
func (s *RequestResponseServer) serveResponder(w http.ResponseWriter, r *http.Request, channel chan PatchedRequest) {
	log.Println("responder connection")
	// not all HTTP clients can read the response headers before sending the request body.
	// switching splits the transaction across 2 requests, providing the responder
	// with a random channel for the second request, and connecting that to the original
	// requester. These are referred to as "double clutch" requests.
	switchChannel := r.URL.Query().Get("switch") == "true"
	var switchChannelID string
	if switchChannel {
		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			// A broken client connection must not take down the whole server,
			// so report it to this client instead of calling log.Fatal.
			log.Println("reading switching channel:", err)
			w.WriteHeader(400)
			w.Write([]byte("Failed to read switching channel"))
			return
		}
		if len(body) == 0 {
			w.WriteHeader(400)
			w.Write([]byte("No switching channel provided"))
			return
		}
		switchChannelID = "/" + string(body)
	}

	var request PatchedRequest
	select {
	case request = <-channel:
	case <-r.Context().Done():
		log.Println("responder canceled")
		return
	}

	w.Header().Add("Pb-Uri", request.uri)
	w.Header().Add("Pb-Method", request.httpRequest.Method)
	for k, vList := range request.httpRequest.Header {
		for _, v := range vList {
			w.Header().Add(RequestPrefix+k, v)
		}
	}

	if switchChannel {
		io.Copy(w, request.httpRequest.Body)
		// Buffered so the hand-off below never blocks; the second responder
		// request picks the requester up from this channel later.
		// NOTE(review): this entry is never removed from s.channels.
		ch := make(chan PatchedRequest, 1)
		s.mutex.Lock()
		s.channels[switchChannelID] = ch
		s.mutex.Unlock()
		fmt.Println("switched it to", switchChannelID)
		ch <- request
		return
	}

	doneSignal := make(chan struct{})
	response := PatchedReponse{responderRequest: r, doneSignal: doneSignal}
	if contentType := mime.TypeByExtension(path.Ext(r.URL.Path)); contentType != "" {
		r.Header.Set(ResponsePrefix+"Content-Type", contentType)
	}
	io.Copy(w, request.httpRequest.Body)
	select {
	case request.responseChan <- response:
	case <-request.httpRequest.Context().Done():
		// The requester left before taking the response. Nobody will ever
		// close doneSignal, so don't block on it.
		log.Println("requester canceled before response")
		return
	}
	// r.Body must stay readable until the requester has finished copying it.
	<-doneSignal
}

// serveRequester hands r to a responder on channel and streams the
// responder's reply back through w.
func (s *RequestResponseServer) serveRequester(w http.ResponseWriter, r *http.Request, channel chan PatchedRequest) {
	log.Println("requester connection")
	responseChan := make(chan PatchedReponse)
	request := PatchedRequest{uri: r.URL.RequestURI(), httpRequest: r, responseChan: responseChan}
	select {
	case channel <- request:
	case <-r.Context().Done():
		log.Println("requester canceled before connection")
		return
	}

	var response PatchedReponse
	select {
	case response = <-responseChan:
	case <-r.Context().Done():
		// For example, a switched responder never made its second request.
		// Without this case the handler goroutine would block forever.
		log.Println("requester canceled before response")
		return
	}
	// Always release the waiting responder, even if writing the response panics.
	defer close(response.doneSignal)

	status := 0
	for k, vList := range response.responderRequest.Header {
		switch {
		case strings.HasPrefix(k, ResponsePrefix):
			headerName := strings.TrimPrefix(k, ResponsePrefix)
			for _, v := range vList {
				w.Header().Add(headerName, v)
			}
		case strings.HasPrefix(k, "Pb-Status"):
			if len(vList) == 0 {
				continue
			}
			code, err := strconv.Atoi(vList[0])
			if err != nil || code < 100 || code > 999 {
				// This value comes from the responder, so it is untrusted. Log
				// it and keep the default status rather than exiting the
				// process or panicking in WriteHeader, which rejects codes
				// outside 100-999.
				log.Printf("ignoring invalid Pb-Status value %q", vList[0])
				continue
			}
			status = code
		}
	}
	if status != 0 {
		w.WriteHeader(status)
	}

	// If the requester disconnects mid-stream, closing the responder body
	// unblocks io.Copy. copyDone stops the watcher once the copy has finished,
	// because the context is also canceled on every normal return.
	copyDone := make(chan struct{})
	defer close(copyDone)
	go func() {
		select {
		case <-r.Context().Done():
			fmt.Println("requester canceled after connection")
			response.responderRequest.Body.Close()
		case <-copyDone:
		}
	}()
	io.Copy(w, response.responderRequest.Body)
}