Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: add batch request support #99

Merged
merged 3 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
cb(w)
}

var req request
// We read the entire request upfront in a buffer to be able to tell if the
// client sent more than maxRequestSize and report it back as an explicit error,
// instead of just silently truncating it and reporting a more vague parsing
Expand All @@ -205,11 +204,11 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
if err != nil {
// ReadFrom will discard EOF so any error here is unexpected and should
// be reported.
rpcError(wf, &req, rpcParseError, xerrors.Errorf("reading request: %w", err))
rpcError(wf, nil, rpcParseError, xerrors.Errorf("reading request: %w", err))
return
}
if reqSize > s.maxRequestSize {
rpcError(wf, &req, rpcParseError,
rpcError(wf, nil, rpcParseError,
// rpcParseError is the closest we have from the standard errors defined
// in [jsonrpc spec](https://www.jsonrpc.org/specification#error_object)
// to report the maximum limit.
Expand All @@ -218,17 +217,52 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
return
}

if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("unmarshaling request: %w", err))
if reqSize == 0 {
rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request"))
return
}

if req.ID, err = normalizeID(req.ID); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
return
}
if bufferedRequest.Bytes()[0] == '[' && bufferedRequest.Bytes()[reqSize-1] == ']' {
var reqs []request

if err := json.NewDecoder(bufferedRequest).Decode(&reqs); err != nil {
rpcError(wf, nil, rpcParseError, xerrors.New("Parse error"))
return
}

if len(reqs) == 0 {
rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request"))
return
}

w.Write([]byte("["))
for idx, req := range reqs {
if req.ID, err = normalizeID(req.ID); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
return
}

s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)

if idx != len(reqs)-1 {
w.Write([]byte(","))
}
}
w.Write([]byte("]"))
} else {
var req request
if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.New("Parse error"))
return
}

if req.ID, err = normalizeID(req.ID); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
return
}

s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
}
}

func doCall(methodName string, f reflect.Value, params []reflect.Value) (out []reflect.Value, err error) {
Expand Down
29 changes: 27 additions & 2 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ func TestRawRequests(t *testing.T) {
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()

removeSpaces := func(jsonStr string) (string, error) {
var jsonObj interface{}
err := json.Unmarshal([]byte(jsonStr), &jsonObj)
if err != nil {
return "", err
}

compactJSONBytes, err := json.Marshal(jsonObj)
if err != nil {
return "", err
}

return string(compactJSONBytes), nil
}

tc := func(req, resp string, n int32) func(t *testing.T) {
return func(t *testing.T) {
rpcHandler.n = 0
Expand All @@ -100,7 +115,13 @@ func TestRawRequests(t *testing.T) {
b, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)

assert.Equal(t, resp, strings.TrimSpace(string(b)))
expectedResp, err := removeSpaces(resp)
require.NoError(t, err)

responseBody, err := removeSpaces(string(b))
require.NoError(t, err)

assert.Equal(t, expectedResp, responseBody)
require.Equal(t, n, rpcHandler.n)
}
}
Expand All @@ -109,7 +130,11 @@ func TestRawRequests(t *testing.T) {
t.Run("inc-null", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": null, "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1))
t.Run("inc-noparam", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "id": 2}`, `{"jsonrpc":"2.0","id":2}`, 1))
t.Run("add", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [10], "id": 4}`, `{"jsonrpc":"2.0","id":4}`, 10))

// Batch requests
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 5}`, `{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"Parse error"}}`, 0))
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 6}]`, `[{"jsonrpc":"2.0","id":6}]`, 123))
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 7},{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-122], "id": 8}]`, `[{"jsonrpc":"2.0","id":7},{"jsonrpc":"2.0","id":8}]`, 1))
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 9},{"jsonrpc": "2.0", "params": [-122], "id": 10}]`, `[{"jsonrpc":"2.0","id":9},{"error":{"code":-32601,"message":"method '' not found"},"id":10,"jsonrpc":"2.0"}]`, 123))
}

func TestReconnection(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

const (
rpcParseError = -32700
rpcInvalidRequest = -32600
rpcMethodNotFound = -32601
rpcInvalidParams = -32602
)
Expand Down Expand Up @@ -107,13 +108,13 @@ func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error)
log.Errorf("RPC Error: %s", err)
wf(func(w io.Writer) {
if hw, ok := w.(http.ResponseWriter); ok {
hw.WriteHeader(500)
hw.WriteHeader(200)
}

log.Warnf("rpc error: %s", err)

if req.ID == nil { // notification
return
if req == nil {
req = &request{}
}

resp := response{
Expand Down