From bcf9171a20e44ed81a6eb152e3ca9e35b2c02c5d Mon Sep 17 00:00:00 2001 From: Arjan Singh Bal <46515553+arjan-bal@users.noreply.github.com> Date: Mon, 23 Sep 2024 21:39:46 +0530 Subject: [PATCH] transport: Fix reporting of bytes read while reading headers (#7660) --- internal/transport/transport.go | 2 +- internal/transport/transport_test.go | 31 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 9a3bb3a63eae..e12cb0bc914b 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -625,7 +625,7 @@ func (t *transportReader) ReadHeader(header []byte) (int, error) { t.er = err return 0, err } - t.windowHandler(len(header)) + t.windowHandler(n) return n, nil } diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 05f0b0b2e35f..65efb30c4bb6 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -2845,3 +2845,34 @@ func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) { isGreetingDone.Store(true) ct.Close(errors.New("manually closed by client")) } + +// TestReadHeaderMultipleBuffers tests the stream when the gRPC headers are +// split across multiple buffers. It verifies that the reporting of the +// number of bytes read for flow control is correct. +func (s) TestReadHeaderMultipleBuffers(t *testing.T) { + headerLen := 5 + recvBuffer := newRecvBuffer() + recvBuffer.put(recvMsg{buffer: make(mem.SliceBuffer, 3)}) + recvBuffer.put(recvMsg{buffer: make(mem.SliceBuffer, headerLen-3)}) + bytesRead := 0 + s := Stream{ + requestRead: func(int) {}, + trReader: &transportReader{ + reader: &recvBufferReader{ + recv: recvBuffer, + }, + windowHandler: func(i int) { + bytesRead += i + }, + }, + } + + header := make([]byte, headerLen) + err := s.ReadHeader(header) + if err != nil { + t.Fatalf("ReadHeader(%v) = %v", header, err) + } + if bytesRead != headerLen { + t.Errorf("bytesRead = %d, want = %d", bytesRead, headerLen) + } +}