diff --git a/common/persistence/sql/sqlplugin/interfaces.go b/common/persistence/sql/sqlplugin/interfaces.go index 649fbb7f9ac3..b702dd3f58cd 100644 --- a/common/persistence/sql/sqlplugin/interfaces.go +++ b/common/persistence/sql/sqlplugin/interfaces.go @@ -137,3 +137,14 @@ type ( PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) } ) + +func (k DbKind) String() string { + switch k { + case DbKindMain: + return "main" + case DbKindVisibility: + return "visibility" + default: + return "unknown" + } +} diff --git a/common/persistence/sql/sqlplugin/mysql/plugin.go b/common/persistence/sql/sqlplugin/mysql/plugin.go index b93b2fe42b47..40262eeef7e9 100644 --- a/common/persistence/sql/sqlplugin/mysql/plugin.go +++ b/common/persistence/sql/sqlplugin/mysql/plugin.go @@ -53,7 +53,7 @@ func (p *plugin) CreateDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.DB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(session.MySQLVersion5_7, dbKind, cfg, r) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (p *plugin) CreateAdminDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.AdminDB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(session.MySQLVersion5_7, sqlplugin.DbKindMain, cfg, r) if err != nil { return nil, err } @@ -80,10 +80,12 @@ func (p *plugin) CreateAdminDB( // SQL database and the object can be used to perform CRUD operations on // the tables in the database func (p *plugin) createDBConnection( + version session.MySQLVersion, + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*sqlx.DB, error) { - mysqlSession, err := session.NewSession(cfg, resolver) + mysqlSession, err := session.NewSession(version, dbKind, cfg, resolver) if err != nil { return nil, err } diff --git a/common/persistence/sql/sqlplugin/mysql/plugin_v8.go b/common/persistence/sql/sqlplugin/mysql/plugin_v8.go index 6146fbb864ea..f89fed218edf 100644 --- a/common/persistence/sql/sqlplugin/mysql/plugin_v8.go +++ b/common/persistence/sql/sqlplugin/mysql/plugin_v8.go @@ -28,6 +28,7 @@ import ( "go.temporal.io/server/common/config" "go.temporal.io/server/common/persistence/sql" "go.temporal.io/server/common/persistence/sql/sqlplugin" + "go.temporal.io/server/common/persistence/sql/sqlplugin/mysql/session" "go.temporal.io/server/common/resolver" ) @@ -52,7 +53,7 @@ func (p *pluginV8) CreateDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.DB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(session.MySQLVersion8_0, dbKind, cfg, r) if err != nil { return nil, err } @@ -66,7 +67,7 @@ func (p *pluginV8) CreateAdminDB( cfg *config.SQL, r resolver.ServiceResolver, ) (sqlplugin.AdminDB, error) { - conn, err := p.createDBConnection(cfg, r) + conn, err := p.createDBConnection(session.MySQLVersion8_0, sqlplugin.DbKindMain, cfg, r) if err != nil { return nil, err } diff --git a/common/persistence/sql/sqlplugin/mysql/session/session.go b/common/persistence/sql/sqlplugin/mysql/session/session.go index b9490d5043cf..58a8f33a5169 100644 --- a/common/persistence/sql/sqlplugin/mysql/session/session.go +++ b/common/persistence/sql/sqlplugin/mysql/session/session.go @@ -27,6 +27,7 @@ package session import ( "crypto/tls" "crypto/x509" + "errors" "fmt" "os" "strings" @@ -37,9 +38,25 @@ import ( "go.temporal.io/server/common/auth" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/resolver" ) +type ( + Session struct { + *sqlx.DB + } + + // MySQLVersion specifies which of the distinct mysql versions we support + MySQLVersion int +) + +const ( + MySQLVersionUnspecified MySQLVersion = iota + MySQLVersion5_7 + MySQLVersion8_0 +) + const ( driverName = "mysql" @@ -48,22 +65,40 @@ const ( defaultIsolationLevel = "'READ-COMMITTED'" // customTLSName is the name used if a custom tls configuration is created customTLSName = "tls-custom" + + interpolateParamsAttr = "interpolateParams" ) -var dsnAttrOverrides = map[string]string{ - "parseTime": "true", - "clientFoundRows": "true", -} +var ( + errMySQL8VisInterpolateParamsNotSupported = errors.New("interpolateParams is not supported for mysql8 visibility stores") + dsnAttrOverrides = map[string]string{ + "parseTime": "true", + "clientFoundRows": "true", + } +) -type Session struct { - *sqlx.DB +func (m MySQLVersion) String() string { + switch m { + case MySQLVersion5_7: + return "MySQL 5.7" + case MySQLVersion8_0: + return "MySQL 8.0" + default: + return "Unspecified" + } } func NewSession( + version MySQLVersion, + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*Session, error) { - db, err := createConnection(cfg, resolver) + if version == MySQLVersionUnspecified { + return nil, fmt.Errorf("Bug: unspecified MySQL version provided to NewSession") + } + + db, err := createConnection(version, dbKind, cfg, resolver) if err != nil { return nil, err } @@ -77,6 +112,8 @@ func (s *Session) Close() { } func createConnection( + version MySQLVersion, + dbKind sqlplugin.DbKind, cfg *config.SQL, resolver resolver.ServiceResolver, ) (*sqlx.DB, error) { @@ -85,7 +122,12 @@ func createConnection( return nil, err } - db, err := sqlx.Connect(driverName, buildDSN(cfg, resolver)) + dsn, err := buildDSN(version, dbKind, cfg, resolver) + if err != nil { + return nil, err + } + + db, err := sqlx.Connect(driverName, dsn) if err != nil { return nil, err } @@ -104,7 +146,12 @@ func createConnection( return db, nil } -func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { +func buildDSN( + version MySQLVersion, + dbKind sqlplugin.DbKind, + cfg *config.SQL, + r resolver.ServiceResolver, +) (string, error) { mysqlConfig := mysql.NewConfig() mysqlConfig.User = cfg.User @@ -112,7 +159,11 @@ func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { mysqlConfig.Addr = r.Resolve(cfg.ConnectAddr)[0] mysqlConfig.DBName = cfg.DatabaseName mysqlConfig.Net = cfg.ConnectProtocol - mysqlConfig.Params = buildDSNAttrs(cfg) + var err error + mysqlConfig.Params, err = buildDSNAttrs(version, dbKind, cfg) + if err != nil { + return "", err + } // https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L104-L106 // https://github.com/go-sql-driver/mysql/blob/v1.5.0/dsn.go#L182-L189 @@ -124,10 +175,10 @@ func buildDSN(cfg *config.SQL, r resolver.ServiceResolver) string { // https://github.com/temporalio/temporal/issues/1703 mysqlConfig.RejectReadOnly = true - return mysqlConfig.FormatDSN() + return mysqlConfig.FormatDSN(), nil } -func buildDSNAttrs(cfg *config.SQL) map[string]string { +func buildDSNAttrs(version MySQLVersion, dbKind sqlplugin.DbKind, cfg *config.SQL) (map[string]string, error) { attrs := make(map[string]string, len(dsnAttrOverrides)+len(cfg.ConnectAttributes)+1) for k, v := range cfg.ConnectAttributes { k1, v1 := sanitizeAttr(k, v) @@ -145,7 +196,13 @@ func buildDSNAttrs(cfg *config.SQL) map[string]string { attrs[k] = v } - return attrs + if version == MySQLVersion8_0 && dbKind == sqlplugin.DbKindVisibility { + if _, ok := attrs[interpolateParamsAttr]; ok { + return nil, errMySQL8VisInterpolateParamsNotSupported + } + } + + return attrs, nil } func hasAttr(attrs map[string]string, key string) bool { diff --git a/common/persistence/sql/sqlplugin/mysql/session/session_test.go b/common/persistence/sql/sqlplugin/mysql/session/session_test.go index 6343daecba8e..8cb6bd56c05f 100644 --- a/common/persistence/sql/sqlplugin/mysql/session/session_test.go +++ b/common/persistence/sql/sqlplugin/mysql/session/session_test.go @@ -25,6 +25,7 @@ package session import ( + "fmt" "net/url" "strings" "testing" @@ -33,6 +34,7 @@ import ( "github.com/stretchr/testify/suite" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/resolver" ) @@ -66,12 +68,15 @@ func (s *sessionTestSuite) TearDownTest() { func (s *sessionTestSuite) TestBuildDSN() { testCases := []struct { - in config.SQL - outURLPath string - outIsolationKey string - outIsolationVal string + name string + in config.SQL + outURLPath string + outIsolationKey string + outIsolationVal string + expectInvalidConfig bool }{ { + name: "no connect attributes", in: config.SQL{ User: "test", Password: "pass", @@ -84,6 +89,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "with connect attributes", in: config.SQL{ User: "test", Password: "pass", @@ -97,6 +103,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (quoted, shorthand)", in: config.SQL{ User: "test", Password: "pass", @@ -110,6 +117,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (unquoted, shorthand)", in: config.SQL{ User: "test", Password: "pass", @@ -123,6 +131,7 @@ func (s *sessionTestSuite) TestBuildDSN() { outURLPath: "test:pass@tcp(192.168.0.1:3306)/db1?", }, { + name: "override isolation level (unquoted, full name)", in: config.SQL{ User: "test", Password: "pass", @@ -137,21 +146,47 @@ func (s *sessionTestSuite) TestBuildDSN() { }, } - for _, tc := range testCases { - r := resolver.NewMockServiceResolver(s.controller) - r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr}) - - out := buildDSN(&tc.in, r) - s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path") - tokens := strings.Split(out, "?") - s.Equal(2, len(tokens), "invalid url") - qry, err := url.Parse("?" + tokens[1]) - s.NoError(err) - wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal) - s.Equal(wantAttrs, qry.Query(), "invalid dsn url params") + for _, version := range []MySQLVersion{MySQLVersion5_7, MySQLVersion8_0} { + for _, dbKind := range []sqlplugin.DbKind{sqlplugin.DbKindMain, sqlplugin.DbKindVisibility} { + for _, tc := range testCases { + s.Run(fmt.Sprintf("%s %s: %s", version.String(), dbKind.String(), tc.name), func() { + r := resolver.NewMockServiceResolver(s.controller) + r.EXPECT().Resolve(tc.in.ConnectAddr).Return([]string{tc.in.ConnectAddr}) + + out, err := buildDSN(version, dbKind, &tc.in, r) + if tc.expectInvalidConfig { + s.Error(err, "Expected an invalid configuration error") + } else { + s.NoError(err) + } + s.True(strings.HasPrefix(out, tc.outURLPath), "invalid url path") + tokens := strings.Split(out, "?") + s.Equal(2, len(tokens), "invalid url") + qry, err := url.Parse("?" + tokens[1]) + s.NoError(err) + wantAttrs := buildExpectedURLParams(tc.in.ConnectAttributes, tc.outIsolationKey, tc.outIsolationVal) + s.Equal(wantAttrs, qry.Query(), "invalid dsn url params") + }) + } + } } } +func (s *sessionTestSuite) Test_MySQL8_Visibility_DoesntSupport_interpolateParams() { + config := config.SQL{ + User: "test", + Password: "pass", + ConnectProtocol: "tcp", + ConnectAddr: "192.168.0.1:3306", + DatabaseName: "db1", + ConnectAttributes: map[string]string{"interpolateParams": "ignored"}, + } + r := resolver.NewMockServiceResolver(s.controller) + r.EXPECT().Resolve(config.ConnectAddr).Return([]string{config.ConnectAddr}) + _, err := buildDSN(MySQLVersion8_0, sqlplugin.DbKindVisibility, &config, r) + s.Error(err, "We should return an error when a MySQL8 Visibility database is configured with interpolateParams") +} + func buildExpectedURLParams(attrs map[string]string, isolationKey string, isolationValue string) url.Values { result := make(map[string][]string, len(dsnAttrOverrides)+len(attrs)+1) for k, v := range attrs {