From 7578ae3b199aa2fff383be12916f758894e570ed Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 2 Dec 2024 12:04:53 +0300 Subject: [PATCH] client: add support for go1.23 iterators --- client/request.go | 131 ++++++++++++++++++++++++++ client/request_test.go | 203 ++++++++++++++++++++++++++++++++++++++++ client/response.go | 27 ++++++ client/response_test.go | 79 ++++++++++++++++ 4 files changed, 440 insertions(+) diff --git a/client/request.go b/client/request.go index a86d927c4e..92f1558d53 100644 --- a/client/request.go +++ b/client/request.go @@ -5,8 +5,10 @@ import ( "context" "errors" "io" + "iter" "path/filepath" "reflect" + "slices" "strconv" "sync" "time" @@ -129,6 +131,29 @@ func (r *Request) Header(key string) []string { return r.header.PeekMultiple(key) } +// Headers returns all headers in the request using an iterator. +// You can use maps.Collect() to collect all headers into a map. +// +// The returned value is valid until the request object is released. +// Any future calls to Headers method will return the modified value. Do not store references to returned value. Make copies instead. +func (r *Request) Headers() iter.Seq2[string, []string] { + return func(yield func(string, []string) bool) { + keys := r.header.PeekKeys() + + for _, key := range keys { + vals := r.header.PeekAll(utils.UnsafeString(key)) + valsStr := make([]string, len(vals)) + for i, v := range vals { + valsStr[i] = utils.UnsafeString(v) + } + + if !yield(utils.UnsafeString(key), valsStr) { + return + } + } + } +} + // AddHeader method adds a single header field and its value in the request instance. func (r *Request) AddHeader(key, val string) *Request { r.header.Add(key, val) @@ -168,6 +193,33 @@ func (r *Request) Param(key string) []string { return res } +// Params returns all params in the request using an iterator. +// You can use maps.Collect() to collect all params into a map. +// +// The returned value is valid until the request object is released. +// Any future calls to Params method will return the modified value. Do not store references to returned value. Make copies instead. +func (r *Request) Params() iter.Seq2[string, []string] { + return func(yield func(string, []string) bool) { + keys := r.params.Keys() + + for _, key := range keys { + if key == "" { + continue + } + + vals := r.params.PeekMulti(key) + valsStr := make([]string, len(vals)) + for i, v := range vals { + valsStr[i] = utils.UnsafeString(v) + } + + if !yield(key, valsStr) { + return + } + } + } +} + // AddParam method adds a single param field and its value in the request instance. func (r *Request) AddParam(key, val string) *Request { r.params.Add(key, val) @@ -254,6 +306,18 @@ func (r *Request) Cookie(key string) string { return "" } +// Cookies returns all cookies in the cookies using an iterator. +// You can use maps.Collect() to collect all cookies into a map. +func (r *Request) Cookies() iter.Seq2[string, string] { + return func(yield func(string, string) bool) { + r.cookies.VisitAll(func(key, val string) { + if !yield(key, val) { + return + } + }) + } +} + // SetCookie method sets a single cookie field and its value in the request instance. // It will override cookie which set in client instance. func (r *Request) SetCookie(key, val string) *Request { @@ -291,6 +355,18 @@ func (r *Request) PathParam(key string) string { return "" } +// PathParams returns all path params in request instance. +// You can use maps.Collect() to collect all cookies into a map. +func (r *Request) PathParams() iter.Seq2[string, string] { + return func(yield func(string, string) bool) { + r.path.VisitAll(func(key, val string) { + if !yield(key, val) { + return + } + }) + } +} + // SetPathParam method sets a single path param field and its value in the request instance. // It will override path param which set in client instance. func (r *Request) SetPathParam(key, val string) *Request { @@ -376,6 +452,33 @@ func (r *Request) FormData(key string) []string { return res } +// FormDatas method returns all form datas in request instance. +// You can use maps.Collect() to collect all cookies into a map. +// +// The returned value is valid until the request object is released. +// Any future calls to FormDatas method will return the modified value. Do not store references to returned value. Make copies instead. +func (r *Request) FormDatas() iter.Seq2[string, []string] { + return func(yield func(string, []string) bool) { + keys := r.formData.Keys() + + for _, key := range keys { + if key == "" { + continue + } + + vals := r.formData.PeekMulti(key) + valsStr := make([]string, len(vals)) + for i, v := range vals { + valsStr[i] = utils.UnsafeString(v) + } + + if !yield(key, valsStr) { + return + } + } + } +} + // AddFormData method adds a single form data field and its value in the request instance. func (r *Request) AddFormData(key, val string) *Request { r.formData.AddData(key, val) @@ -435,6 +538,14 @@ func (r *Request) File(name string) *File { return nil } +// Files method returns all files in request instance. +// +// The returned value is valid until the request object is released. +// Any future calls to Files method will return the modified value. Do not store references to returned value. Make copies instead. +func (r *Request) Files() []*File { + return r.files +} + // FileByPath returns file ptr store in request obj by path. func (r *Request) FileByPath(path string) *File { for _, v := range r.files { @@ -617,6 +728,16 @@ type QueryParam struct { *fasthttp.Args } +// Keys method returns all keys in the query params. +func (f *QueryParam) Keys() []string { + keys := make([]string, f.Len()) + f.VisitAll(func(key, value []byte) { + keys = append(keys, utils.UnsafeString(key)) + }) + + return slices.Compact(keys) +} + // AddParams receive a map and add each value to param. func (p *QueryParam) AddParams(r map[string][]string) { for k, v := range r { @@ -747,6 +868,16 @@ type FormData struct { *fasthttp.Args } +// Keys method returns all keys in the form data. +func (f *FormData) Keys() []string { + keys := make([]string, f.Len()) + f.VisitAll(func(key, value []byte) { + keys = append(keys, utils.UnsafeString(key)) + }) + + return slices.Compact(keys) +} + // AddData method is a wrapper of Args's Add method. func (f *FormData) AddData(key, val string) { f.Add(key, val) diff --git a/client/request_test.go b/client/request_test.go index f62865a342..0b5c866395 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "io" + "maps" "mime/multipart" "net" "os" @@ -157,6 +158,40 @@ func Test_Request_Header(t *testing.T) { }) } +func Test_Request_Headers(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + req.AddHeaders(map[string][]string{ + "foo": {"bar", "fiber"}, + "bar": {"foo"}, + }) + + headers := maps.Collect(req.Headers()) + + require.Contains(t, headers["Foo"], "fiber") + require.Contains(t, headers["Foo"], "bar") + require.Contains(t, headers["Bar"], "foo") + + require.Len(t, headers, 2) +} + +func Benchmark_Request_Headers(b *testing.B) { + req := AcquireRequest() + req.AddHeaders(map[string][]string{ + "foo": {"bar", "fiber"}, + "bar": {"foo"}, + }) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for range req.Headers() { + } + } +} + func Test_Request_QueryParam(t *testing.T) { t.Parallel() @@ -282,6 +317,40 @@ func Test_Request_QueryParam(t *testing.T) { }) } +func Test_Request_Params(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + req.AddParams(map[string][]string{ + "foo": {"bar", "fiber"}, + "bar": {"foo"}, + }) + + pathParams := maps.Collect(req.Params()) + + require.Contains(t, pathParams["foo"], "bar") + require.Contains(t, pathParams["foo"], "fiber") + require.Contains(t, pathParams["bar"], "foo") + + require.Len(t, pathParams, 2) +} + +func Benchmark_Request_Params(b *testing.B) { + req := AcquireRequest() + req.AddParams(map[string][]string{ + "foo": {"bar", "fiber"}, + "bar": {"foo"}, + }) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for range req.Params() { + } + } +} + func Test_Request_UA(t *testing.T) { t.Parallel() @@ -364,6 +433,39 @@ func Test_Request_Cookie(t *testing.T) { }) } +func Test_Request_Cookies(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + req.SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + + cookies := maps.Collect(req.Cookies()) + + require.Equal(t, "bar", cookies["foo"]) + require.Equal(t, "foo", cookies["bar"]) + + require.Len(t, cookies, 2) +} + +func Benchmark_Request_Cookies(b *testing.B) { + req := AcquireRequest() + req.SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for range req.Cookies() { + } + } +} + func Test_Request_PathParam(t *testing.T) { t.Parallel() @@ -441,6 +543,39 @@ func Test_Request_PathParam(t *testing.T) { }) } +func Test_Request_PathParams(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + req.SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + + pathParams := maps.Collect(req.PathParams()) + + require.Equal(t, "bar", pathParams["foo"]) + require.Equal(t, "foo", pathParams["bar"]) + + require.Len(t, pathParams, 2) +} + +func Benchmark_Request_PathParams(b *testing.B) { + req := AcquireRequest() + req.SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for range req.PathParams() { + } + } +} + func Test_Request_FormData(t *testing.T) { t.Parallel() @@ -610,6 +745,40 @@ func Test_Request_File(t *testing.T) { }) } +func Test_Request_Files(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + req.AddFile("../.github/index.html") + req.AddFiles(AcquireFile(SetFileName("tmp.txt"))) + + files := req.Files() + + require.Equal(t, "../.github/index.html", files[0].path) + require.Nil(t, files[0].reader) + + require.Equal(t, "tmp.txt", files[1].name) + require.Nil(t, files[1].reader) + + require.Len(t, files, 2) +} + +func Benchmark_Request_Files(b *testing.B) { + req := AcquireRequest() + req.AddFile("../.github/index.html") + req.AddFiles(AcquireFile(SetFileName("tmp.txt"))) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for k, v := range req.Files() { + _ = k + _ = v + } + } +} + func Test_Request_Timeout(t *testing.T) { t.Parallel() @@ -1181,6 +1350,40 @@ func Test_Request_Body_With_Server(t *testing.T) { }) } +func Test_Request_FormDatas(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + req.AddFormDatas(map[string][]string{ + "foo": {"bar", "fiber"}, + "bar": {"foo"}, + }) + + pathParams := maps.Collect(req.FormDatas()) + + require.Contains(t, pathParams["foo"], "bar") + require.Contains(t, pathParams["foo"], "fiber") + require.Contains(t, pathParams["bar"], "foo") + + require.Len(t, pathParams, 2) +} + +func Benchmark_Request_FormDatas(b *testing.B) { + req := AcquireRequest() + req.AddFormDatas(map[string][]string{ + "foo": {"bar", "fiber"}, + "bar": {"foo"}, + }) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for range req.FormDatas() { + } + } +} + func Test_Request_Error_Body_With_Server(t *testing.T) { t.Parallel() t.Run("json error", func(t *testing.T) { diff --git a/client/response.go b/client/response.go index e60c6bd0fb..8d21329774 100644 --- a/client/response.go +++ b/client/response.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/fs" + "iter" "os" "path/filepath" "sync" @@ -55,7 +56,33 @@ func (r *Response) Header(key string) string { return utils.UnsafeString(r.RawResponse.Header.Peek(key)) } +// Headers returns all headers in the response using an iterator. +// You can use maps.Collect() to collect all headers into a map. +// +// The returned value is valid until the response object is released. +// Any future calls to Headers method will return the modified value. Do not store references to returned value. Make copies instead. +func (r *Response) Headers() iter.Seq2[string, []string] { + return func(yield func(string, []string) bool) { + keys := r.RawResponse.Header.PeekKeys() + + for _, key := range keys { + vals := r.RawResponse.Header.PeekAll(utils.UnsafeString(key)) + valsStr := make([]string, len(vals)) + for i, v := range vals { + valsStr[i] = utils.UnsafeString(v) + } + + if !yield(utils.UnsafeString(key), valsStr) { + return + } + } + } +} + // Cookies method to access all the response cookies. +// +// The returned value is valid until the response object is released. +// Any future calls to Cookies method will return the modified value. Do not store references to returned value. Make copies instead. func (r *Response) Cookies() []*fasthttp.Cookie { return r.cookie } diff --git a/client/response_test.go b/client/response_test.go index bf12e75161..340f1d8a2f 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "encoding/xml" + "fmt" "io" "net" "os" @@ -199,6 +200,84 @@ func Test_Response_Header(t *testing.T) { resp.Close() } +func Test_Response_Headers(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + c.Response().Header.Add("foo", "bar") + c.Response().Header.Add("foo", "bar2") + c.Response().Header.Add("foo2", "bar") + + return c.SendString("helo world") + }) + }) + defer server.stop() + + client := New().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + + headers := make(map[string][]string) + for k, v := range resp.Headers() { + headers[k] = make([]string, 0) + for _, value := range v { + fmt.Print(string(value)) + headers[k] = append(headers[k], string(value)) + } + } + + require.Contains(t, headers["Foo"], "bar") + require.Contains(t, headers["Foo"], "bar2") + require.Contains(t, headers["Foo2"], "bar") + + resp.Close() +} + +func Benchmark_Headers(b *testing.B) { + server := startTestServer( + b, + func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + c.Response().Header.Add("foo", "bar") + c.Response().Header.Add("foo", "bar2") + c.Response().Header.Add("foo", "bar3") + + c.Response().Header.Add("foo2", "bar") + c.Response().Header.Add("foo2", "bar2") + c.Response().Header.Add("foo2", "bar3") + + return c.SendString("helo world") + }) + }, + ) + + defer server.stop() + + client := New().SetDial(server.dial()) + + b.ResetTimer() + b.ReportAllocs() + + var err error + var resp *Response + for i := 0; i < b.N; i++ { + resp, err = AcquireRequest(). + SetClient(client). + Get("http://example.com") + + for range resp.Headers() { + } + + resp.Close() + } + require.NoError(b, err) +} + func Test_Response_Cookie(t *testing.T) { t.Parallel()