From 5efd7bd73e11fea58d1c7f1c110902e78a286299 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Tue, 10 Oct 2023 14:05:12 -0700 Subject: [PATCH] server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703) (#6708) --- internal/transport/http2_server.go | 11 +-- internal/transport/transport_test.go | 35 +++++---- server.go | 69 ++++++++++++----- server_ext_test.go | 110 +++++++++++++++++++++++++++ 4 files changed, 180 insertions(+), 45 deletions(-) create mode 100644 server_ext_test.go diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 79e86ba08836..ec4eef21342a 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, ID: http2.SettingMaxFrameSize, Val: http2MaxFrameLen, }} - // TODO(zhaoq): Have a better way to signal "no limit" because 0 is - // permitted in the HTTP2 spec. - maxStreams := config.MaxStreams - if maxStreams == 0 { - maxStreams = math.MaxUint32 - } else { + if config.MaxStreams != math.MaxUint32 { isettings = append(isettings, http2.Setting{ ID: http2.SettingMaxConcurrentStreams, - Val: maxStreams, + Val: config.MaxStreams, }) } dynamicWindow := true @@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, framer: framer, readerDone: make(chan struct{}), writerDone: make(chan struct{}), - maxStreams: maxStreams, + maxStreams: config.MaxStreams, inTapHandle: config.InTapHandle, fc: &trInFlow{limit: uint32(icwz)}, state: reachable, diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c0d85b2a88d8..8c04d2f5c44d 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -336,6 +336,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT if err != nil { return } + if serverConfig.MaxStreams == 0 { + serverConfig.MaxStreams = math.MaxUint32 + } transport, err := NewServerTransport(conn, serverConfig) if err != nil { return @@ -442,8 +445,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server return server } -func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) { - return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{}) +func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) { + return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{}) } func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) { @@ -538,7 +541,7 @@ func (s) TestInflightStreamClosing(t *testing.T) { // Tests that when streamID > MaxStreamId, the current client transport drains. func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() defer server.stop() callHdr := &CallHdr{ @@ -583,7 +586,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { } func (s) TestClientSendAndReceive(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -623,7 +626,7 @@ func (s) TestClientSendAndReceive(t *testing.T) { } func (s) TestClientErrorNotify(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() go server.stop() // ct.reader should detect the error and activate ct.Error(). @@ -657,7 +660,7 @@ func performOneRPC(ct ClientTransport) { } func (s) TestClientMix(t *testing.T) { - s, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + s, ct, cancel := setUp(t, 0, normal) defer cancel() time.AfterFunc(time.Second, s.stop) go func(ct ClientTransport) { @@ -671,7 +674,7 @@ func (s) TestClientMix(t *testing.T) { } func (s) TestLargeMessage(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -806,7 +809,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) { // proceed until they complete naturally, while not allowing creation of new // streams during this window. func (s) TestGracefulClose(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong) + server, ct, cancel := setUp(t, 0, pingpong) defer cancel() defer func() { // Stop the server's listener to make the server's goroutines terminate @@ -872,7 +875,7 @@ func (s) TestGracefulClose(t *testing.T) { } func (s) TestLargeMessageSuspension(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended) + server, ct, cancel := setUp(t, 0, suspended) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -980,7 +983,7 @@ func (s) TestMaxStreams(t *testing.T) { } func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended) + server, ct, cancel := setUp(t, 0, suspended) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -1452,7 +1455,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) { var encodingTestStatus = status.New(codes.Internal, "\n") func (s) TestEncodingRequiredStatus(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) + server, ct, cancel := setUp(t, 0, encodingRequiredStatus) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -1480,7 +1483,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) { } func (s) TestInvalidHeaderField(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) + server, ct, cancel := setUp(t, 0, invalidHeaderField) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -1502,7 +1505,7 @@ func (s) TestInvalidHeaderField(t *testing.T) { } func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) + server, ct, cancel := setUp(t, 0, invalidHeaderField) defer cancel() defer server.stop() defer ct.Close(fmt.Errorf("closed manually by test")) @@ -2170,7 +2173,7 @@ func (s) TestPingPong1MB(t *testing.T) { // This is a stress-test of flow control logic. func runPingPongTest(t *testing.T, msgSize int) { - server, client, cancel := setUp(t, 0, 0, pingpong) + server, client, cancel := setUp(t, 0, pingpong) defer cancel() defer server.stop() defer client.Close(fmt.Errorf("closed manually by test")) @@ -2252,7 +2255,7 @@ func (s) TestHeaderTblSize(t *testing.T) { } }() - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() defer ct.Close(fmt.Errorf("closed manually by test")) defer server.stop() @@ -2611,7 +2614,7 @@ func TestConnectionError_Unwrap(t *testing.T) { func (s) TestPeerSetInServerContext(t *testing.T) { // create client and server transports. - server, client, cancel := setUp(t, 0, math.MaxUint32, normal) + server, client, cancel := setUp(t, 0, normal) defer cancel() defer server.stop() defer client.Close(fmt.Errorf("closed manually by test")) diff --git a/server.go b/server.go index 81969e7c15a9..8869cc906f25 100644 --- a/server.go +++ b/server.go @@ -115,12 +115,6 @@ type serviceInfo struct { mdata interface{} } -type serverWorkerData struct { - st transport.ServerTransport - wg *sync.WaitGroup - stream *transport.Stream -} - // Server is a gRPC server to serve RPC requests. type Server struct { opts serverOptions @@ -145,7 +139,7 @@ type Server struct { channelzID *channelz.Identifier czData *channelzData - serverWorkerChannel chan *serverWorkerData + serverWorkerChannel chan func() } type serverOptions struct { @@ -177,6 +171,7 @@ type serverOptions struct { } var defaultServerOptions = serverOptions{ + maxConcurrentStreams: math.MaxUint32, maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, maxSendMessageSize: defaultServerMaxSendMessageSize, connectionTimeout: 120 * time.Second, @@ -387,6 +382,9 @@ func MaxSendMsgSize(m int) ServerOption { // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // of concurrent streams to each ServerTransport. func MaxConcurrentStreams(n uint32) ServerOption { + if n == 0 { + n = math.MaxUint32 + } return newFuncServerOption(func(o *serverOptions) { o.maxConcurrentStreams = n }) @@ -567,24 +565,19 @@ const serverWorkerResetThreshold = 1 << 16 // [1] https://github.com/golang/go/issues/18138 func (s *Server) serverWorker() { for completed := 0; completed < serverWorkerResetThreshold; completed++ { - data, ok := <-s.serverWorkerChannel + f, ok := <-s.serverWorkerChannel if !ok { return } - s.handleSingleStream(data) + f() } go s.serverWorker() } -func (s *Server) handleSingleStream(data *serverWorkerData) { - defer data.wg.Done() - s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream)) -} - // initServerWorkers creates worker goroutines and a channel to process incoming // connections to reduce the time spent overall on runtime.morestack. func (s *Server) initServerWorkers() { - s.serverWorkerChannel = make(chan *serverWorkerData) + s.serverWorkerChannel = make(chan func()) for i := uint32(0); i < s.opts.numServerWorkers; i++ { go s.serverWorker() } @@ -943,21 +936,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) { defer st.Close(errors.New("finished serving streams for the server transport")) var wg sync.WaitGroup + streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) st.HandleStreams(func(stream *transport.Stream) { wg.Add(1) + + streamQuota.acquire() + f := func() { + defer streamQuota.release() + defer wg.Done() + s.handleStream(st, stream, s.traceInfo(st, stream)) + } + if s.opts.numServerWorkers > 0 { - data := &serverWorkerData{st: st, wg: &wg, stream: stream} select { - case s.serverWorkerChannel <- data: + case s.serverWorkerChannel <- f: return default: // If all stream workers are busy, fallback to the default code path. } } - go func() { - defer wg.Done() - s.handleStream(st, stream, s.traceInfo(st, stream)) - }() + go f() }, func(ctx context.Context, method string) context.Context { if !EnableTracing { return ctx @@ -2052,3 +2050,32 @@ func validateSendCompressor(name, clientCompressors string) error { } return fmt.Errorf("client does not support compressor %q", name) } + +// atomicSemaphore implements a blocking, counting semaphore. acquire should be +// called synchronously; release may be called asynchronously. +type atomicSemaphore struct { + n int64 + wait chan struct{} +} + +func (q *atomicSemaphore) acquire() { + if atomic.AddInt64(&q.n, -1) < 0 { + // We ran out of quota. Block until a release happens. + <-q.wait + } +} + +func (q *atomicSemaphore) release() { + // N.B. the "<= 0" check below should allow for this to work with multiple + // concurrent calls to acquire, but also note that with synchronous calls to + // acquire, as our system does, n will never be less than -1. There are + // fairness issues (queuing) to consider if this was to be generalized. + if atomic.AddInt64(&q.n, 1) <= 0 { + // An acquire was waiting on us. Unblock it. + q.wait <- struct{}{} + } +} + +func newHandlerQuota(n uint32) *atomicSemaphore { + return &atomicSemaphore{n: int64(n), wait: make(chan struct{}, 1)} +} diff --git a/server_ext_test.go b/server_ext_test.go new file mode 100644 index 000000000000..dab7a80be5b1 --- /dev/null +++ b/server_ext_test.go @@ -0,0 +1,110 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc_test + +import ( + "context" + "io" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" +) + +const defaultTestTimeout = 10 * time.Second + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// TestServer_MaxHandlers ensures that no more than MaxConcurrentStreams server +// handlers are active at one time. +func (s) TestServer_MaxHandlers(t *testing.T) { + started := make(chan struct{}) + blockCalls := grpcsync.NewEvent() + + // This stub server does not properly respect the stream context, so it will + // not exit when the context is canceled. + ss := stubserver.StubServer{ + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + started <- struct{}{} + <-blockCalls.Done() + return nil + }, + } + if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Start one RPC to the server. + ctx1, cancel1 := context.WithCancel(ctx) + _, err := ss.Client.FullDuplexCall(ctx1) + if err != nil { + t.Fatal("Error staring call:", err) + } + + // Wait for the handler to be invoked. + select { + case <-started: + case <-ctx.Done(): + t.Fatalf("Timed out waiting for RPC to start on server.") + } + + // Cancel it on the client. The server handler will still be running. + cancel1() + + ctx2, cancel2 := context.WithCancel(ctx) + defer cancel2() + s, err := ss.Client.FullDuplexCall(ctx2) + if err != nil { + t.Fatal("Error staring call:", err) + } + + // After 100ms, allow the first call to unblock. That should allow the + // second RPC to run and finish. + select { + case <-started: + blockCalls.Fire() + t.Fatalf("RPC started unexpectedly.") + case <-time.After(100 * time.Millisecond): + blockCalls.Fire() + } + + select { + case <-started: + case <-ctx.Done(): + t.Fatalf("Timed out waiting for second RPC to start on server.") + } + if _, err := s.Recv(); err != io.EOF { + t.Fatal("Received unexpected RPC error:", err) + } +}