From af3184d4ac5cea5e50d29a704d1b7a6e3cb2e22f Mon Sep 17 00:00:00 2001 From: Tim Deeb-Swihart Date: Tue, 20 Feb 2024 11:16:03 -0800 Subject: [PATCH] Ban the connect attr interpolateParams for MySQL 8 Vis dbs There are a number of reasons why this doesn't work right now so we're going to prevent it from happening while we discuss whether we should invest in fixing it. The current work to fix this is in https://github.com/temporalio/temporal/pull/5428 --- .../persistence/sql/sqlplugin/interfaces.go | 11 +++ .../persistence/sql/sqlplugin/mysql/plugin.go | 8 +- .../sql/sqlplugin/mysql/plugin_v8.go | 5 +- .../sql/sqlplugin/mysql/session/session.go | 83 ++++++++++++++++--- .../sqlplugin/mysql/session/session_test.go | 67 +++++++++++---- 5 files changed, 140 insertions(+), 34 deletions(-) diff --git a/common/persistence/sql/sqlplugin/interfaces.go b/common/persistence/sql/sqlplugin/interfaces.go index 649fbb7f9ac3..b702dd3f58cd 100644 --- a/common/persistence/sql/sqlplugin/interfaces.go +++ b/common/persistence/sql/sqlplugin/interfaces.go @@ -137,3 +137,14 @@ type ( PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) } ) + +func (k DbKind) String() string { + switch k { + case DbKindMain: + return "main" + case DbKindVisibility: + return "visibility" + default: + return "unknown" + } +} diff --git a/common/persistence/sql/sqlplugin/mysql/plugin.go b/common/persistence/sql/sqlplugin/mysql/plugin.go index b93b2fe42b47..40262eeef7e9 100644 --- a/common/persistence/sql/sqlplugin/mysql/plugin.go +++ b/common/persistence/sql/sqlplugin/mysql/plugin.go @@ -53,7 +53,7 @@ func (p *plugin) CreateDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.DB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(session.MySQLVersion5_7, dbKind, cfg, r) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (p *plugin) CreateAdminDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.AdminDB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(session.MySQLVersion5_7, sqlplugin.DbKindMain, cfg, r) if err != nil { return nil, err } @@ -80,10 +80,12 @@ func (p *plugin) CreateAdminDB( // SQL database and the object can be used to perform CRUD operations on // the tables in the database func (p *plugin) createDBConnection( + version session.MySQLVersion, + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*sqlx.DB, error) { - mysqlSession, err := session.NewSession(cfg, resolver) + mysqlSession, err := session.NewSession(version, dbKind, cfg, resolver) if err != nil { return nil, err } diff --git a/common/persistence/sql/sqlplugin/mysql/plugin_v8.go b/common/persistence/sql/sqlplugin/mysql/plugin_v8.go index 6146fbb864ea..f89fed218edf 100644 --- a/common/persistence/sql/sqlplugin/mysql/plugin_v8.go +++ b/common/persistence/sql/sqlplugin/mysql/plugin_v8.go @@ -28,6 +28,7 @@ import ( "go.temporal.io/server/common/config" "go.temporal.io/server/common/persistence/sql" "go.temporal.io/server/common/persistence/sql/sqlplugin" + "go.temporal.io/server/common/persistence/sql/sqlplugin/mysql/session" "go.temporal.io/server/common/resolver" ) @@ -52,7 +53,7 @@ func (p *pluginV8) CreateDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.DB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(session.MySQLVersion8_0, dbKind, cfg, r) if err != nil { return nil, err } @@ -66,7 +67,7 @@ func (p *pluginV8) CreateAdminDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.AdminDB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(session.MySQLVersion8_0, sqlplugin.DbKindMain, cfg, r) if err != nil { return nil, err } diff --git a/common/persistence/sql/sqlplugin/mysql/session/session.go b/common/persistence/sql/sqlplugin/mysql/session/session.go index b9490d5043cf..58a8f33a5169 100644 --- a/common/persistence/sql/sqlplugin/mysql/session/session.go +++ b/common/persistence/sql/sqlplugin/mysql/session/session.go @@ -27,6 +27,7 @@ package session import ( "crypto/tls" "crypto/x509" + "errors" "fmt" "os" "strings" @@ -37,9 +38,25 @@ import ( "go.temporal.io/server/common/auth" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/resolver" ) +type ( + Session struct { + *sqlx.DB + } + + // MySQLVersion specifies which of the distinct mysql versions we support + MySQLVersion int +) + +const ( + MySQLVersionUnspecified MySQLVersion = iota + MySQLVersion5_7 + MySQLVersion8_0 +) + const ( driverName = "mysql" @@ -48,22 +65,40 @@ const ( defaultIsolationLevel = "'READ-COMMITTED'" // customTLSName is the name used if a custom tls configuration is created customTLSName = "tls-custom" + + interpolateParamsAttr = "interpolateParams" ) -var dsnAttrOverrides = map[string]string{ - "parseTime": "true", - "clientFoundRows": "true", -} +var ( + errMySQL8VisInterpolateParamsNotSupported = errors.New("interpolateParams is not supported for mysql8 visibility stores") + dsnAttrOverrides = map[string]string{ + "parseTime": "true", + "clientFoundRows": "true", + } +) -type Session struct { - *sqlx.DB +func (m MySQLVersion) String() string { + switch m { + case MySQLVersion5_7: + return "MySQL 5.7" + case MySQLVersion8_0: + return "MySQL 8.0" + default: + return "Unspecified" + } } func NewSession( + version MySQLVersion, + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*Session, error) { - db, err := createConnection(cfg, resolver) + if version == MySQLVersionUnspecified { + return nil, fmt.Errorf("Bug: unspecified MySQL version provided to NewSession") + } + + db, err := createConnection(version, dbKind, cfg, resolver) if err != nil { return nil, err } @@ -77,6 +112,8 @@ func (s *Session) Close() { } func createConnection( + version MySQLVersion, + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*sqlx.DB, error) { @@ -85,7 +122,12 @@ func createConnection( return nil, err } - db, err := sqlx.Connect(driverName, buildDSN(cfg, resolver)) + dsn, err := buildDSN(version, dbKind, cfg, resolver) + if err != nil { + return nil, err + } + + db, err := sqlx.Connect(driverName, dsn) if err != nil { return nil, err } @@ -104,7 +146,12 @@ func createConnection( return db, nil } -func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { +func buildDSN( + version MySQLVersion, + dbKind sqlplugin.DbKind, + cfg *config.SQL, + r resolver.ServiceResolver, +) (string, error) { mysqlConfig := mysql.NewConfig() mysqlConfig.User = cfg.User @@ -112,7 +159,11 @@ func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { mysqlConfig.Addr = r.Resolve(cfg.ConnectAddr)[0] mysqlConfig.DBName = cfg.DatabaseName mysqlConfig.Net = cfg.ConnectProtocol - mysqlConfig.Params = buildDSNAttrs(cfg) + var err error + mysqlConfig.Params, err = buildDSNAttrs(version, dbKind, cfg) + if err != nil { + return "", err + } // https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L104-L106 // https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L182-L189 @@ -124,10 +175,10 @@ func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { // https://github.com/temporalio/temporal/issues/1703 mysqlConfig.RejectReadOnly = true - return mysqlConfig.FormatDSN() + return mysqlConfig.FormatDSN(), nil } -func buildDSNAttrs(cfg *config.SQL) map[string]string { +func buildDSNAttrs(version MySQLVersion, dbKind sqlplugin.DbKind, cfg *config.SQL) (map[string]string, error) { attrs := make(map[string]string, len(dsnAttrOverrides)+len(cfg.ConnectAttributes)+1) for k, v := range cfg.ConnectAttributes { k1, v1 := sanitizeAttr(k, v) @@ -145,7 +196,13 @@ func buildDSNAttrs(cfg *config.SQL) map[string]string { attrs[k] = v } - return attrs + if version == MySQLVersion8_0 && dbKind == sqlplugin.DbKindVisibility { + if _, ok := attrs[interpolateParamsAttr]; ok { + return nil, errMySQL8VisInterpolateParamsNotSupported + } + } + + return attrs, nil } func hasAttr(attrs map[string]string, key string) bool { diff --git a/common/persistence/sql/sqlplugin/mysql/session/session_test.go b/common/persistence/sql/sqlplugin/mysql/session/session_test.go index 6343daecba8e..8cb6bd56c05f 100644 --- a/common/persistence/sql/sqlplugin/mysql/session/session_test.go +++ b/common/persistence/sql/sqlplugin/mysql/session/session_test.go @@ -25,6 +25,7 @@ package session import ( + "fmt" "net/url" "strings" "testing" @@ -33,6 +34,7 @@ import ( "github.com/stretchr/testify/suite" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/resolver" ) @@ -66,12 +68,15 @@ func (s *sessionTestSuite) TearDownTest() { func (s *sessionTestSuite) TestBuildDSN() { testCases := []struct { - in config.SQL - outURLPath string - outIsolationKey string - outIsolationVal string + name string + in config.SQL + outURLPath string + outIsolationKey string + outIsolationVal string + expectInvalidConfig bool }{ { + name: "no connect attributes", in: config.SQL{ User: "test", Password: "pass", @@ -84,6 +89,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "with connect attributes", in: config.SQL{ User: "test", Password: "pass", @@ -97,6 +103,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (quoted, shorthand)", in: config.SQL{ User: "test", Password: "pass", @@ -110,6 +117,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (unquoted, shorthand)", in: config.SQL{ User: "test", Password: "pass", @@ -123,6 +131,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (unquoted, full name)", in: config.SQL{ User: "test", Password: "pass", @@ -137,21 +146,47 @@ func (s *sessionTestSuite) TestBuildDSN() { }, } - for _, tc := range testCases { - r := resolver.NewMockServiceResolver(s.controller) - r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr}) - - out := buildDSN(&tc.in, r) - s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path") - tokens := strings.Split(out, "?") - s.Equal(2, len(tokens), "invalid url") - qry, err := url.Parse("?" + tokens[1]) - s.NoError(err) - wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal) - s.Equal(wantAttrs, qry.Query(), "invalid dsn url params") + for _, version := range []MySQLVersion{MySQLVersion5_7, MySQLVersion8_0} { + for _, dbKind := range []sqlplugin.DbKind{sqlplugin.DbKindMain, sqlplugin.DbKindVisibility} { + for _, tc := range testCases { + s.Run(fmt.Sprintf("%s %s: %s", version.String(), dbKind.String(), tc.name), func() { + r := resolver.NewMockServiceResolver(s.controller) + r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr}) + + out, err := buildDSN(version, dbKind, &tc.in, r) + if tc.expectInvalidConfig { + s.Error(err, "Expected an invalid configuration error") + } else { + s.NoError(err) + } + s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path") + tokens := strings.Split(out, "?") + s.Equal(2, len(tokens), "invalid url") + qry, err := url.Parse("?" + tokens[1]) + s.NoError(err) + wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal) + s.Equal(wantAttrs, qry.Query(), "invalid dsn url params") + }) + } + } } } +func (s *sessionTestSuite) Test_MySQL8_Visibility_DoesntSupport_interpolateParams() { + config := config.SQL{ + User: "test", + Password: "pass", + ConnectProtocol: "tcp", + ConnectAddr: "192.168.0.1:3306", + DatabaseName: "db1", + ConnectAttributes: map[string]string{"interpolateParams": "ignored"}, + } + r := resolver.NewMockServiceResolver(s.controller) + r.EXPECT().Resolve(config.ConnectAddr).Return([]string{config.ConnectAddr}) + _, err := buildDSN(MySQLVersion8_0, sqlplugin.DbKindVisibility, &config, r) + s.Error(err, "We should return an error when a MySQL8 Visibility database is configured with interpolateParams") +} + func buildExpectedURLParams(attrs map[string]string, isolationKey string, isolationValue string) url.Values { result := make(map[string][]string, len(dsnAttrOverrides)+len(attrs)+1) for k, v := range attrs {