Skip to content

Commit

Permalink
Add support for drivers which do not implement driver.DriverContext.
Browse files Browse the repository at this point in the history
  • Loading branch information
jbub committed Nov 26, 2021
1 parent 3a70f09 commit 0f9a9c4
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v0.2.0

- Add support for drivers which do not implement `driver.DriverContext`.

## v0.1.1

- Optimize `Attrs` encoding by reducing allocations.
Expand Down
22 changes: 20 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ func (d *commentDriver) Open(name string) (driver.Conn, error) {
}

func (d *commentDriver) OpenConnector(name string) (driver.Connector, error) {
ctr, err := d.drv.(driver.DriverContext).OpenConnector(name)
drvCtx, ok := d.drv.(driver.DriverContext)
if !ok {
return &dsnConnector{dsn: name, drv: d}, nil
}

ctr, err := drvCtx.OpenConnector(name)
if err != nil {
return nil, err
}
return newConnector(ctr, d), err
return newConnector(ctr, d), nil
}

func newConnector(ctr driver.Connector, drv *commentDriver) *connector {
Expand All @@ -56,3 +61,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
func (c *connector) Driver() driver.Driver {
return c.drv
}

type dsnConnector struct {
dsn string
drv *commentDriver
}

func (c *dsnConnector) Connect(context.Context) (driver.Conn, error) {
return c.drv.Open(c.dsn)
}

func (c *dsnConnector) Driver() driver.Driver {
return c.drv
}
84 changes: 56 additions & 28 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,51 @@ func TestWrapDriver(t *testing.T) {
},
}

drivers := []struct {
name string
newDriver func(conn *mockConn) driver.Driver
}{
{
name: "driver",
newDriver: func(conn *mockConn) driver.Driver {
return &mockDriver{conn: conn}
},
},
{
name: "driverctx",
newDriver: func(conn *mockConn) driver.Driver {
return &mockDriverContext{conn: conn}
},
},
}

for i, cs := range cases {
t.Run(cs.name, func(t *testing.T) {
var ctx context.Context
if cs.makeCtx != nil {
ctx = cs.makeCtx()
} else {
ctx = context.Background()
}

conn := &mockConn{}
orig := &mockDriver{conn: conn}
drv := WrapDriver(orig, cs.options...)

driverName := fmt.Sprintf("driver-%v", i)
sql.Register(driverName, drv)

db, err := sql.Open(driverName, "")
if err != nil {
t.Fatal(err)
}
defer db.Close()

cs.perform(ctx, db)
cs.assert(t, conn)
})
for j, drv := range drivers {
t.Run(cs.name+" "+drv.name, func(t *testing.T) {
var ctx context.Context
if cs.makeCtx != nil {
ctx = cs.makeCtx()
} else {
ctx = context.Background()
}

conn := &mockConn{}
orig := drv.newDriver(conn)
drv := WrapDriver(orig, cs.options...)

driverName := fmt.Sprintf("driver-%v-%v", i, j)
sql.Register(driverName, drv)

db, err := sql.Open(driverName, "")
if err != nil {
t.Fatal(err)
}
defer db.Close()

cs.perform(ctx, db)
cs.assert(t, conn)
})
}
}
}

Expand All @@ -126,15 +146,15 @@ func userKeyFromContext(ctx context.Context) string {
return ctx.Value(contextUserKey).(string)
}

type mockDriver struct {
type mockDriverContext struct {
conn *mockConn
}

func (m *mockDriver) Open(name string) (driver.Conn, error) {
func (m *mockDriverContext) Open(name string) (driver.Conn, error) {
return m.conn, nil
}

func (m *mockDriver) OpenConnector(name string) (driver.Connector, error) {
func (m *mockDriverContext) OpenConnector(name string) (driver.Connector, error) {
return &mockConnector{
drv: m,
conn: m.conn,
Expand Down Expand Up @@ -178,7 +198,7 @@ func (m *mockConn) assertExecContext(t *testing.T, query string) {
}

type mockConnector struct {
drv *mockDriver
drv *mockDriverContext
conn *mockConn
}

Expand All @@ -189,3 +209,11 @@ func (m *mockConnector) Connect(ctx context.Context) (driver.Conn, error) {
func (m *mockConnector) Driver() driver.Driver {
return m.drv
}

type mockDriver struct {
conn *mockConn
}

func (m *mockDriver) Open(name string) (driver.Conn, error) {
return m.conn, nil
}

0 comments on commit 0f9a9c4

Please # to comment.