From 9a8ae0dff5d9806b75ab8b7887500382494ce059 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Wed, 1 Jan 2025 13:14:49 +0200 Subject: [PATCH] cmd/atlas/internal/cloudapi: expose http error to callers --- cmd/atlas/internal/cloudapi/client.go | 24 ++++++++-- cmd/atlas/internal/cloudapi/client_test.go | 55 +++++++++++++++++++++- 2 files changed, 75 insertions(+), 4 deletions(-) diff --git a/cmd/atlas/internal/cloudapi/client.go b/cmd/atlas/internal/cloudapi/client.go index a75aa1c2165..c3d8716946e 100644 --- a/cmd/atlas/internal/cloudapi/client.go +++ b/cmd/atlas/internal/cloudapi/client.go @@ -51,6 +51,9 @@ func New(endpoint, token string) *Client { transport = client.HTTPClient.Transport ) client.HTTPClient.Timeout = time.Second * 30 + client.ErrorHandler = func(res *http.Response, err error, _ int) (*http.Response, error) { + return res, err // Let Client.post handle the error. + } client.HTTPClient.Transport = &roundTripper{ token: token, base: transport, @@ -284,16 +287,21 @@ func (c *Client) post(ctx context.Context, query string, vars, data any) error { if err != nil { return err } - defer req.Body.Close() + defer res.Body.Close() switch { case res.StatusCode == http.StatusUnauthorized: return ErrUnauthorized case res.StatusCode != http.StatusOK: + buf, err := io.ReadAll(io.LimitReader(res.Body, 1<<20)) + if err != nil { + return &HTTPError{StatusCode: res.StatusCode, Message: err.Error()} + } var v struct { Errors errlist `json:"errors,omitempty"` } - if err := json.NewDecoder(res.Body).Decode(&v); err != nil || len(v.Errors) == 0 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) + if err := json.Unmarshal(buf, &v); err != nil || len(v.Errors) == 0 { + // If the error is not a GraphQL error, return the message as is. + return &HTTPError{StatusCode: res.StatusCode, Message: string(bytes.TrimSpace(buf))} } return v.Errors } @@ -347,6 +355,16 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.base.RoundTrip(req) } +// HTTPError represents a generic HTTP error. Hence, non 2xx status codes. +type HTTPError struct { + StatusCode int + Message string +} + +func (e *HTTPError) Error() string { + return fmt.Sprintf("unexpected error code %d: %s", e.StatusCode, e.Message) +} + // RedactedURL returns a URL string with the userinfo redacted. func RedactedURL(s string) (string, error) { u, err := sqlclient.ParseURL(s) diff --git a/cmd/atlas/internal/cloudapi/client_test.go b/cmd/atlas/internal/cloudapi/client_test.go index 8b54abfd5ef..d962e35d089 100644 --- a/cmd/atlas/internal/cloudapi/client_test.go +++ b/cmd/atlas/internal/cloudapi/client_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/http/httptest" "runtime" + "strings" "testing" "ariga.io/atlas/sql/migrate" @@ -53,7 +54,7 @@ func TestClient_Dir(t *testing.T) { require.Equal(t, dcheck.Sum(), gcheck.Sum()) } -func TestClient_Error(t *testing.T) { +func TestClient_GraphQLError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnprocessableEntity) _, err := w.Write([]byte(`{"errors":[{"message":"error\n","path":["variable","input","driver"],"extensions":{}}],"data":null}`)) @@ -69,6 +70,58 @@ func TestClient_Error(t *testing.T) { require.Empty(t, link) } +func TestClient_HTTPError(t *testing.T) { + var ( + body string + code = http.StatusInternalServerError + ) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, body, code) + })) + client := New(srv.URL, "atlas") + defer srv.Close() + body = "internal error" + _, err := client.ReportMigration(context.Background(), ReportMigrationInput{ + EnvName: "foo", + ProjectName: "bar", + }) + require.EqualError(t, err, `unexpected error code 500: internal error`) + + // Error should be limited to 1MB. + body = fmt.Sprintf("%s!", strings.Repeat("a", 1<<20)) + _, err = client.ReportMigration(context.Background(), ReportMigrationInput{ + EnvName: "foo", + ProjectName: "bar", + }) + require.ErrorContains(t, err, "unexpected error code 500: a") + require.NotContains(t, err.Error(), "!") + + // Unauthorized error. + body = "unauthorized" + code = http.StatusUnauthorized + _, err = client.ReportMigration(context.Background(), ReportMigrationInput{ + EnvName: "foo", + ProjectName: "bar", + }) + require.ErrorIs(t, err, ErrUnauthorized) + + code = http.StatusForbidden + body = "Forbidden" + _, err = client.ReportMigration(context.Background(), ReportMigrationInput{ + EnvName: "foo", + ProjectName: "bar", + }) + require.EqualError(t, err, "unexpected error code 403: Forbidden") + + code = http.StatusConflict + body = `{"errors":[{"message":"conflict\n","path":["variable","input","driver"],"extensions":{}}],"data":null}` + _, err = client.ReportMigration(context.Background(), ReportMigrationInput{ + EnvName: "foo", + ProjectName: "bar", + }) + require.EqualError(t, err, "variable.input.driver conflict", "GraphQL error") +} + func TestClient_ReportMigration(t *testing.T) { const project, env = "atlas", "dev" srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {