Skip to content

Commit

Permalink
Ban the connect attr interpolateParams for MySQL 8 Vis dbs
Browse files Browse the repository at this point in the history
There are a number of reasons why this doesn't work right now so we're
going to prevent it from happening while we discuss whether we should
invest in fixing it. The current work to fix this is in #5428
  • Loading branch information
tdeebswihart committed Feb 29, 2024
1 parent 39ec799 commit af3184d
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 34 deletions.
11 changes: 11 additions & 0 deletions common/persistence/sql/sqlplugin/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,14 @@ type (
PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
}
)

func (k DbKind) String() string {
switch k {
case DbKindMain:
return "main"
case DbKindVisibility:
return "visibility"
default:
return "unknown"
}
}
8 changes: 5 additions & 3 deletions common/persistence/sql/sqlplugin/mysql/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (p *plugin) CreateDB(
cfg *config.SQL,
r resolver.ServiceResolver,
) (sqlplugin.DB, error) {
conn, err := p.createDBConnection(cfg, r)
conn, err := p.createDBConnection(session.MySQLVersion5_7, dbKind, cfg, r)
if err != nil {
return nil, err
}
Expand All @@ -67,7 +67,7 @@ func (p *plugin) CreateAdminDB(
cfg *config.SQL,
r resolver.ServiceResolver,
) (sqlplugin.AdminDB, error) {
conn, err := p.createDBConnection(cfg, r)
conn, err := p.createDBConnection(session.MySQLVersion5_7, sqlplugin.DbKindMain, cfg, r)
if err != nil {
return nil, err
}
Expand All @@ -80,10 +80,12 @@ func (p *plugin) CreateAdminDB(
// SQL database and the object can be used to perform CRUD operations on
// the tables in the database
func (p *plugin) createDBConnection(
version session.MySQLVersion,
dbKind sqlplugin.DbKind,
cfg *config.SQL,
resolver resolver.ServiceResolver,
) (*sqlx.DB, error) {
mysqlSession, err := session.NewSession(cfg, resolver)
mysqlSession, err := session.NewSession(version, dbKind, cfg, resolver)
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions common/persistence/sql/sqlplugin/mysql/plugin_v8.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"go.temporal.io/server/common/config"
"go.temporal.io/server/common/persistence/sql"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/persistence/sql/sqlplugin/mysql/session"
"go.temporal.io/server/common/resolver"
)

Expand All @@ -52,7 +53,7 @@ func (p *pluginV8) CreateDB(
cfg *config.SQL,
r resolver.ServiceResolver,
) (sqlplugin.DB, error) {
conn, err := p.createDBConnection(cfg, r)
conn, err := p.createDBConnection(session.MySQLVersion8_0, dbKind, cfg, r)
if err != nil {
return nil, err
}
Expand All @@ -66,7 +67,7 @@ func (p *pluginV8) CreateAdminDB(
cfg *config.SQL,
r resolver.ServiceResolver,
) (sqlplugin.AdminDB, error) {
conn, err := p.createDBConnection(cfg, r)
conn, err := p.createDBConnection(session.MySQLVersion8_0, sqlplugin.DbKindMain, cfg, r)
if err != nil {
return nil, err
}
Expand Down
83 changes: 70 additions & 13 deletions common/persistence/sql/sqlplugin/mysql/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ package session
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"os"
"strings"
Expand All @@ -37,9 +38,25 @@ import (

"go.temporal.io/server/common/auth"
"go.temporal.io/server/common/config"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/resolver"
)

type (
Session struct {
*sqlx.DB
}

// MySQLVersion specifies which of the distinct mysql versions we support
MySQLVersion int
)

const (
MySQLVersionUnspecified MySQLVersion = iota
MySQLVersion5_7
MySQLVersion8_0
)

const (
driverName = "mysql"

Expand All @@ -48,22 +65,40 @@ const (
defaultIsolationLevel = "'READ-COMMITTED'"
// customTLSName is the name used if a custom tls configuration is created
customTLSName = "tls-custom"

interpolateParamsAttr = "interpolateParams"
)

var dsnAttrOverrides = map[string]string{
"parseTime": "true",
"clientFoundRows": "true",
}
var (
errMySQL8VisInterpolateParamsNotSupported = errors.New("interpolateParams is not supported for mysql8 visibility stores")
dsnAttrOverrides = map[string]string{
"parseTime": "true",
"clientFoundRows": "true",
}
)

type Session struct {
*sqlx.DB
func (m MySQLVersion) String() string {
switch m {
case MySQLVersion5_7:
return "MySQL 5.7"
case MySQLVersion8_0:
return "MySQL 8.0"
default:
return "Unspecified"
}
}

func NewSession(
version MySQLVersion,
dbKind sqlplugin.DbKind,
cfg *config.SQL,
resolver resolver.ServiceResolver,
) (*Session, error) {
db, err := createConnection(cfg, resolver)
if version == MySQLVersionUnspecified {
return nil, fmt.Errorf("Bug: unspecified MySQL version provided to NewSession")
}

db, err := createConnection(version, dbKind, cfg, resolver)
if err != nil {
return nil, err
}
Expand All @@ -77,6 +112,8 @@ func (s *Session) Close() {
}

func createConnection(
version MySQLVersion,
dbKind sqlplugin.DbKind,
cfg *config.SQL,
resolver resolver.ServiceResolver,
) (*sqlx.DB, error) {
Expand All @@ -85,7 +122,12 @@ func createConnection(
return nil, err
}

db, err := sqlx.Connect(driverName, buildDSN(cfg, resolver))
dsn, err := buildDSN(version, dbKind, cfg, resolver)
if err != nil {
return nil, err
}

db, err := sqlx.Connect(driverName, dsn)
if err != nil {
return nil, err
}
Expand All @@ -104,15 +146,24 @@ func createConnection(
return db, nil
}

func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string {
func buildDSN(
version MySQLVersion,
dbKind sqlplugin.DbKind,
cfg *config.SQL,
r resolver.ServiceResolver,
) (string, error) {
mysqlConfig := mysql.NewConfig()

mysqlConfig.User = cfg.User
mysqlConfig.Passwd = cfg.Password
mysqlConfig.Addr = r.Resolve(cfg.ConnectAddr)[0]
mysqlConfig.DBName = cfg.DatabaseName
mysqlConfig.Net = cfg.ConnectProtocol
mysqlConfig.Params = buildDSNAttrs(cfg)
var err error
mysqlConfig.Params, err = buildDSNAttrs(version, dbKind, cfg)
if err != nil {
return "", err
}

// https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L104-L106
// https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L182-L189
Expand All @@ -124,10 +175,10 @@ func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string {
// https://github.com/temporalio/temporal/issues/1703
mysqlConfig.RejectReadOnly = true

return mysqlConfig.FormatDSN()
return mysqlConfig.FormatDSN(), nil
}

func buildDSNAttrs(cfg *config.SQL) map[string]string {
func buildDSNAttrs(version MySQLVersion, dbKind sqlplugin.DbKind, cfg *config.SQL) (map[string]string, error) {
attrs := make(map[string]string, len(dsnAttrOverrides)+len(cfg.ConnectAttributes)+1)
for k, v := range cfg.ConnectAttributes {
k1, v1 := sanitizeAttr(k, v)
Expand All @@ -145,7 +196,13 @@ func buildDSNAttrs(cfg *config.SQL) map[string]string {
attrs[k] = v
}

return attrs
if version == MySQLVersion8_0 && dbKind == sqlplugin.DbKindVisibility {
if _, ok := attrs[interpolateParamsAttr]; ok {
return nil, errMySQL8VisInterpolateParamsNotSupported
}
}

return attrs, nil
}

func hasAttr(attrs map[string]string, key string) bool {
Expand Down
67 changes: 51 additions & 16 deletions common/persistence/sql/sqlplugin/mysql/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package session

import (
"fmt"
"net/url"
"strings"
"testing"
Expand All @@ -33,6 +34,7 @@ import (
"github.com/stretchr/testify/suite"

"go.temporal.io/server/common/config"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/resolver"
)

Expand Down Expand Up @@ -66,12 +68,15 @@ func (s *sessionTestSuite) TearDownTest() {

func (s *sessionTestSuite) TestBuildDSN() {
testCases := []struct {
in config.SQL
outURLPath string
outIsolationKey string
outIsolationVal string
name string
in config.SQL
outURLPath string
outIsolationKey string
outIsolationVal string
expectInvalidConfig bool
}{
{
name: "no connect attributes",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -84,6 +89,7 @@ func (s *sessionTestSuite) TestBuildDSN() {
outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?",
},
{
name: "with connect attributes",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -97,6 +103,7 @@ func (s *sessionTestSuite) TestBuildDSN() {
outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?",
},
{
name: "override isolation level (quoted, shorthand)",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -110,6 +117,7 @@ func (s *sessionTestSuite) TestBuildDSN() {
outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?",
},
{
name: "override isolation level (unquoted, shorthand)",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -123,6 +131,7 @@ func (s *sessionTestSuite) TestBuildDSN() {
outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?",
},
{
name: "override isolation level (unquoted, full name)",
in: config.SQL{
User: "test",
Password: "pass",
Expand All @@ -137,21 +146,47 @@ func (s *sessionTestSuite) TestBuildDSN() {
},
}

for _, tc := range testCases {
r := resolver.NewMockServiceResolver(s.controller)
r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr})

out := buildDSN(&tc.in, r)
s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path")
tokens := strings.Split(out, "?")
s.Equal(2, len(tokens), "invalid url")
qry, err := url.Parse("?" + tokens[1])
s.NoError(err)
wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal)
s.Equal(wantAttrs, qry.Query(), "invalid dsn url params")
for _, version := range []MySQLVersion{MySQLVersion5_7, MySQLVersion8_0} {
for _, dbKind := range []sqlplugin.DbKind{sqlplugin.DbKindMain, sqlplugin.DbKindVisibility} {
for _, tc := range testCases {
s.Run(fmt.Sprintf("%s %s: %s", version.String(), dbKind.String(), tc.name), func() {
r := resolver.NewMockServiceResolver(s.controller)
r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr})

out, err := buildDSN(version, dbKind, &tc.in, r)
if tc.expectInvalidConfig {
s.Error(err, "Expected an invalid configuration error")
} else {
s.NoError(err)
}
s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path")
tokens := strings.Split(out, "?")
s.Equal(2, len(tokens), "invalid url")
qry, err := url.Parse("?" + tokens[1])
s.NoError(err)
wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal)
s.Equal(wantAttrs, qry.Query(), "invalid dsn url params")
})
}
}
}
}

func (s *sessionTestSuite) Test_MySQL8_Visibility_DoesntSupport_interpolateParams() {
config := config.SQL{
User: "test",
Password: "pass",
ConnectProtocol: "tcp",
ConnectAddr: "192.168.0.1:3306",
DatabaseName: "db1",
ConnectAttributes: map[string]string{"interpolateParams": "ignored"},
}
r := resolver.NewMockServiceResolver(s.controller)
r.EXPECT().Resolve(config.ConnectAddr).Return([]string{config.ConnectAddr})
_, err := buildDSN(MySQLVersion8_0, sqlplugin.DbKindVisibility, &config, r)
s.Error(err, "We should return an error when a MySQL8 Visibility database is configured with interpolateParams")
}

func buildExpectedURLParams(attrs map[string]string, isolationKey string, isolationValue string) url.Values {
result := make(map[string][]string, len(dsnAttrOverrides)+len(attrs)+1)
for k, v := range attrs {
Expand Down

0 comments on commit af3184d

Please # to comment.