Skip to content

Commit

Permalink
Repeat requests connection (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
informalict authored Jul 23, 2020
1 parent 27e32ee commit 00bb683
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 0 deletions.
33 changes: 33 additions & 0 deletions http/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ package http
import (
"context"
"encoding/base64"
"errors"
"fmt"
"sync"
"sync/atomic"

driver "github.com/arangodb/go-driver"
)

// ErrAuthenticationNotChanged is returned when authentication is not changed.
var ErrAuthenticationNotChanged = errors.New("authentication not changed")

// Authentication implements a kind of authentication.
type httpAuthentication interface {
// Prepare is called before the first request of the given connection is made.
Expand All @@ -41,6 +45,35 @@ type httpAuthentication interface {
Configure(req driver.Request) error
}

// IsAuthenticationTheSame checks whether two authentications are the same.
func IsAuthenticationTheSame(auth1, auth2 driver.Authentication) bool {

if auth1 == nil && auth2 == nil {
return true
}

if auth1 == nil || auth2 == nil {
return false
}

if auth1.Type() != auth2.Type() {
return false
}

if auth1.Type() == driver.AuthenticationTypeRaw {
if auth1.Get("value") != auth2.Get("value") {
return false
}
} else {
if auth1.Get("username") != auth2.Get("username") ||
auth1.Get("password") != auth2.Get("password") {
return false
}
}

return true
}

// newBasicAuthentication creates an authentication implementation based on the given username & password.
func newBasicAuthentication(userName, password string) httpAuthentication {
auth := fmt.Sprintf("%s:%s", userName, password)
Expand Down
74 changes: 74 additions & 0 deletions http/authentication_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//
// DISCLAIMER
//
// Copyright 2020 ArangoDB GmbH, Cologne, Germany
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright holder is ArangoDB GmbH, Cologne, Germany
//
// Author Tomasz Mielech <tomasz@arangodb.com>
//

package http

import (
"github.com/arangodb/go-driver"
"github.com/stretchr/testify/assert"
"testing"
)

func TestIsAuthenticationTheSame(t *testing.T) {
testCases := map[string]struct {
auth1 driver.Authentication
auth2 driver.Authentication
expected bool
}{
"Two authentications are nil": {
expected: true,
},
"One authentication is nil": {
auth1: driver.BasicAuthentication("", ""),
},
"Different type of authentication": {
auth1: driver.BasicAuthentication("", ""),
auth2: driver.JWTAuthentication("", ""),
},
"Raw authentications are different": {
auth1: driver.RawAuthentication(""),
auth2: driver.RawAuthentication("test"),
},
"Raw authentications are the same": {
auth1: driver.RawAuthentication("test"),
auth2: driver.RawAuthentication("test"),
expected: true,
},
"Basic authentications are different": {
auth1: driver.BasicAuthentication("test", "test"),
auth2: driver.BasicAuthentication("test", "test1"),
},
"Basic authentications are the same": {
auth1: driver.BasicAuthentication("test", "test"),
auth2: driver.BasicAuthentication("test", "test"),
expected: true,
},
}

for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
equal := IsAuthenticationTheSame(testCase.auth1, testCase.auth2)
assert.Equal(t, testCase.expected, equal)
})
}

}
76 changes: 76 additions & 0 deletions http/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"net/http/httptrace"
"net/url"
"strings"
"sync"
"time"

driver "github.com/arangodb/go-driver"
Expand Down Expand Up @@ -413,3 +414,78 @@ func (c *httpConnection) SetAuthentication(auth driver.Authentication) (driver.C
func (c *httpConnection) Protocols() driver.ProtocolSet {
return driver.ProtocolSet{driver.ProtocolHTTP}
}

// RequestRepeater creates possibility to send the request many times.
type RequestRepeater interface {
Repeat(conn driver.Connection, resp driver.Response, err error) bool
}

// RepeatConnection is responsible for sending request until request repeater gives up.
type RepeatConnection struct {
mutex sync.Mutex
auth driver.Authentication
conn driver.Connection
repeat RequestRepeater
}

func NewRepeatConnection(conn driver.Connection, repeat RequestRepeater) driver.Connection {
return &RepeatConnection{
conn: conn,
repeat: repeat,
}
}

// NewRequest creates a new request with given method and path.
func (h *RepeatConnection) NewRequest(method, path string) (driver.Request, error) {
return h.conn.NewRequest(method, path)
}

// Do performs a given request, returning its response. Repeats requests until repeat function gives up.
func (h *RepeatConnection) Do(ctx context.Context, req driver.Request) (driver.Response, error) {
for {
resp, err := h.conn.Do(ctx, req.Clone())

if !h.repeat.Repeat(h, resp, err) {
return resp, err
}
}
}

// Unmarshal unmarshals the given raw object into the given result interface.
func (h *RepeatConnection) Unmarshal(data driver.RawObject, result interface{}) error {
return h.conn.Unmarshal(data, result)
}

// Endpoints returns the endpoints used by this connection.
func (h *RepeatConnection) Endpoints() []string {
return h.conn.Endpoints()
}

// UpdateEndpoints reconfigures the connection to use the given endpoints.
func (h *RepeatConnection) UpdateEndpoints(endpoints []string) error {
return h.conn.UpdateEndpoints(endpoints)
}

// Configure the authentication used for this connection.
// Returns ErrAuthenticationNotChanged in when the authentication is not changed.
func (h *RepeatConnection) SetAuthentication(authentication driver.Authentication) (driver.Connection, error) {
h.mutex.Lock()
defer h.mutex.Unlock()

if IsAuthenticationTheSame(h.auth, authentication) {
return h, ErrAuthenticationNotChanged
}

_, err := h.conn.SetAuthentication(authentication)
if err != nil {
return nil, driver.WithStack(err)
}
h.auth = authentication

return h, nil
}

// Protocols returns all protocols used by this connection.
func (h RepeatConnection) Protocols() driver.ProtocolSet {
return h.conn.Protocols()
}
35 changes: 35 additions & 0 deletions test/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"log"
httplib "net/http"
"os"
Expand Down Expand Up @@ -423,3 +425,36 @@ func TestResponseHeader(t *testing.T) {
}
}
}

type dummyRequestRepeat struct {
counter int
}

func (r *dummyRequestRepeat) Repeat(conn driver.Connection, resp driver.Response, err error) bool {
r.counter++
if r.counter == 2 {
return false
}
return true
}

func TestCreateClientHttpRepeatConnection(t *testing.T) {
if getTestMode() != testModeSingle {
t.Skipf("Not a single")
}
createClientFromEnv(t, true)

requestRepeat := dummyRequestRepeat{}
conn := createConnectionFromEnv(t)
c, err := driver.NewClient(driver.ClientConfig{
Connection: http.NewRepeatConnection(conn, &requestRepeat),
Authentication: createAuthenticationFromEnv(t),
})

_, err = c.Connection().SetAuthentication(createAuthenticationFromEnv(t))
assert.Equal(t, http.ErrAuthenticationNotChanged, err)

_, err = c.Databases(nil)
require.NoError(t, err)
assert.Equal(t, 2, requestRepeat.counter)
}

0 comments on commit 00bb683

Please # to comment.