-
Notifications
You must be signed in to change notification settings - Fork 0
/
connection.go
107 lines (93 loc) · 2.91 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package sqlcommenter
import (
"context"
"database/sql/driver"
)
// Compile-time assertions that *connection satisfies each database/sql/driver
// interface listed below. Assigning a typed nil pointer to the blank
// identifier costs nothing at run time; the build fails if a method is missing.
var (
_ driver.Pinger = (*connection)(nil)
// Execer and Queryer are deprecated in favour of their Context variants;
// the nolint directives silence staticcheck for these references.
_ driver.Execer = (*connection)(nil) // nolint:staticcheck
_ driver.ExecerContext = (*connection)(nil)
_ driver.Queryer = (*connection)(nil) // nolint:staticcheck
_ driver.QueryerContext = (*connection)(nil)
_ driver.Conn = (*connection)(nil)
_ driver.ConnPrepareContext = (*connection)(nil)
_ driver.ConnBeginTx = (*connection)(nil)
_ driver.SessionResetter = (*connection)(nil)
_ driver.NamedValueChecker = (*connection)(nil)
)
// newConn builds a connection that wraps conn and keeps cmt for use when
// queries are passed through withComment.
func newConn(conn driver.Conn, cmt *commenter) *connection {
	wrapped := new(connection)
	wrapped.Conn = conn
	wrapped.cmt = cmt
	return wrapped
}
// connection wraps a driver.Conn together with the commenter that
// withComment uses to rewrite query text before it reaches the wrapped driver.
type connection struct {
// Conn is the wrapped driver connection. It is embedded, so any driver.Conn
// method not overridden on *connection is promoted from it.
driver.Conn
// cmt is the commenter invoked by withComment.
cmt *commenter
}
// PrepareContext prepares query on the wrapped connection. The query text is
// forwarded unchanged; withComment is not applied on this path.
//
// If the wrapped connection implements driver.ConnPrepareContext the call is
// delegated to it. Otherwise the mandatory driver.Conn.Prepare is used.
// Returning driver.ErrSkip here would not help: database/sql only honours
// ErrSkip from the Query/Exec fast paths and CheckNamedValue, so from
// PrepareContext it would reach the caller as a plain error and break every
// prepare on drivers that lack the context-aware variant.
func (c *connection) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	if preparer, ok := c.Conn.(driver.ConnPrepareContext); ok {
		return preparer.PrepareContext(ctx, query)
	}

	stmt, err := c.Conn.Prepare(query)
	if err != nil {
		return nil, err
	}

	// Prepare cannot observe ctx, so honour a cancellation that arrived while
	// it was running and release the statement instead of handing it out.
	select {
	case <-ctx.Done():
		// Best-effort cleanup; the context error is the one worth reporting.
		_ = stmt.Close()
		return nil, ctx.Err()
	default:
	}
	return stmt, nil
}
// BeginTx starts a transaction on the wrapped connection.
//
// If the wrapped connection implements driver.ConnBeginTx the call is
// delegated to it. Otherwise the mandatory driver.Conn.Begin is used.
// database/sql does not treat driver.ErrSkip from BeginTx as a fallback
// request, so returning it would make every transaction fail on drivers that
// lack the context-aware variant.
//
// Begin cannot carry transaction options, so a non-default isolation level or
// a read-only request is rejected on the fallback path rather than silently
// dropped. The error messages match those database/sql reports for the same
// situation.
func (c *connection) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	if beginner, ok := c.Conn.(driver.ConnBeginTx); ok {
		return beginner.BeginTx(ctx, opts)
	}

	if opts.Isolation != driver.IsolationLevel(0) {
		return nil, txOptionsError("sql: driver does not support non-default isolation level")
	}
	if opts.ReadOnly {
		return nil, txOptionsError("sql: driver does not support read-only transactions")
	}

	tx, err := c.Conn.Begin() // nolint:staticcheck
	if err != nil {
		return nil, err
	}

	// Begin cannot observe ctx, so undo the transaction if the context was
	// cancelled while it was being opened.
	select {
	case <-ctx.Done():
		// Best-effort cleanup; the context error is the one worth reporting.
		_ = tx.Rollback()
		return nil, ctx.Err()
	default:
	}
	return tx, nil
}

// txOptionsError reports a transaction option that the legacy
// driver.Conn.Begin fallback in BeginTx cannot honour.
type txOptionsError string

// Error implements the error interface.
func (e txOptionsError) Error() string { return string(e) }
// Query runs query through the wrapped connection's legacy driver.Queryer
// after passing it through withComment with a background context, since this
// signature carries no context. It returns driver.ErrSkip when the wrapped
// connection does not implement driver.Queryer.
func (c *connection) Query(query string, args []driver.Value) (driver.Rows, error) {
	if legacy, ok := c.Conn.(driver.Queryer); ok { // nolint:staticcheck
		commented := c.withComment(context.Background(), query)
		return legacy.Query(commented, args)
	}
	return nil, driver.ErrSkip
}
// QueryContext passes query through withComment using ctx and forwards it,
// with args, to the wrapped connection's driver.QueryerContext. It returns
// driver.ErrSkip when the wrapped connection does not implement that interface.
func (c *connection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	if inner, ok := c.Conn.(driver.QueryerContext); ok {
		commented := c.withComment(ctx, query)
		return inner.QueryContext(ctx, commented, args)
	}
	return nil, driver.ErrSkip
}
// Exec runs query through the wrapped connection's legacy driver.Execer after
// passing it through withComment with a background context, since this
// signature carries no context. It returns driver.ErrSkip when the wrapped
// connection does not implement driver.Execer.
func (c *connection) Exec(query string, args []driver.Value) (driver.Result, error) {
	if legacy, ok := c.Conn.(driver.Execer); ok { // nolint:staticcheck
		commented := c.withComment(context.Background(), query)
		return legacy.Exec(commented, args)
	}
	return nil, driver.ErrSkip
}
// ExecContext passes query through withComment using ctx and forwards it,
// with args, to the wrapped connection's driver.ExecerContext. It returns
// driver.ErrSkip when the wrapped connection does not implement that interface.
func (c *connection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	if inner, ok := c.Conn.(driver.ExecerContext); ok {
		commented := c.withComment(ctx, query)
		return inner.ExecContext(ctx, commented, args)
	}
	return nil, driver.ErrSkip
}
// Ping forwards to the wrapped connection's driver.Pinger when it has one.
//
// If the wrapped connection does not implement driver.Pinger, Ping returns
// nil. database/sql gives driver.ErrSkip no special meaning for Ping, so
// returning it would make DB.Ping fail on such drivers, whereas an unwrapped
// connection without Pinger is simply reported as reachable.
func (c *connection) Ping(ctx context.Context) error {
	pinger, ok := c.Conn.(driver.Pinger)
	if !ok {
		return nil
	}
	return pinger.Ping(ctx)
}
// CheckNamedValue hands value to the wrapped connection's
// driver.NamedValueChecker. It returns driver.ErrSkip when the wrapped
// connection does not implement that interface.
func (c *connection) CheckNamedValue(value *driver.NamedValue) error {
	if inner, ok := c.Conn.(driver.NamedValueChecker); ok {
		return inner.CheckNamedValue(value)
	}
	return driver.ErrSkip
}
// ResetSession forwards to the wrapped connection's driver.SessionResetter
// when it has one.
//
// If the wrapped connection does not implement driver.SessionResetter there
// is nothing to reset, so nil is returned. driver.ErrSkip is not a defined
// result for ResetSession, and an unwrapped connection without the interface
// is treated by database/sql as having reset successfully.
func (c *connection) ResetSession(ctx context.Context) error {
	resetter, ok := c.Conn.(driver.SessionResetter)
	if !ok {
		return nil
	}
	return resetter.ResetSession(ctx)
}
// withComment passes ctx and query to the connection's commenter and returns
// the string it produces.
func (c *connection) withComment(ctx context.Context, query string) string {
	commented := c.cmt.comment(ctx, query)
	return commented
}