diff --git a/client.go b/client.go index 384b99a..8047af4 100644 --- a/client.go +++ b/client.go @@ -167,14 +167,20 @@ func (cli *Client) parseResponse(resp *http.Response) (*Response, error) { resp.Request.Method, resp.Request.URL.String()) } - // Prepare gzip reader for uncompressing gzipped JSON response - ungzipper, err := gzip.NewReader(resp.Body) - if err != nil { - return nil, err + var reader io.ReadCloser + if resp.Header.Get("Content-Encoding") == "gzip" { + // Prepare gzip reader for uncompressing gzipped JSON response + ungzipper, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + defer ungzipper.Close() + reader = ungzipper + } else { + reader = resp.Body } - defer ungzipper.Close() - if err := json.NewDecoder(ungzipper).Decode(apiresp); err != nil { + if err := json.NewDecoder(reader).Decode(apiresp); err != nil { return nil, err } diff --git a/vt_test.go b/vt_test.go index 918f4ac..c10d04a 100644 --- a/vt_test.go +++ b/vt_test.go @@ -3,6 +3,7 @@ package vt import ( "compress/gzip" "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" @@ -30,6 +31,7 @@ type TestServer struct { expectedMethod string response interface{} expectedBody string + status int expectedHeaders map[string]string } @@ -49,6 +51,11 @@ func (ts *TestServer) SetResponse(r interface{}) *TestServer { return ts } +func (ts *TestServer) SetStatusCode(s int) *TestServer { + ts.status = s + return ts +} + func (ts *TestServer) SetExpectedBody(body string) *TestServer { ts.expectedBody = body return ts @@ -94,9 +101,17 @@ func (ts *TestServer) handler(w http.ResponseWriter, r *http.Request) { return } w.Header().Set("Content-Type", "application/json") - gw := gzip.NewWriter(w) - gw.Write(js) - gw.Close() + if ts.status != 0 { + w.WriteHeader(ts.status) + } + if ts.status != 429 { + w.Header().Set("content-encoding", "gzip") + gw := gzip.NewWriter(w) + gw.Write(js) + gw.Close() + } else { + w.Write(js) + } } // This tests GET request with passing in a parameter. @@ -429,3 +444,27 @@ func TestRequestHeadersOverrideGlobalHeaders(t *testing.T) { err := c.PostObject(URL("/collection"), o) assert.NoError(t, err) } + +func TestGetObjectOutOfQuota(t *testing.T) { + ts := NewTestServer(t). + SetExpectedMethod("GET"). + SetStatusCode(429). + SetResponse(map[string]interface{}{ + "error": map[string]interface{}{ + "code": "QuotaExceededError", + "message": "Quota exceeded", + }, + }) + + defer ts.Close() + + SetHost(ts.URL) + c := NewClient("apikey") + _, err := c.GetObject(URL("files/abcabcabcabcabc")) + if err != nil { + var vtErr *Error + if !errors.As(err, &vtErr) && err.(Error).Code != "QuotaExceededError" { + t.Fatalf("Error getting object from VT: %s", err) + } + } +}