Skip to content

Commit

Permalink
*: don't use DSN to avoid some security problems (#38342)
Browse files Browse the repository at this point in the history
  • Loading branch information
lance6716 authored Oct 19, 2022
1 parent 22b85b9 commit d037637
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 130 deletions.
20 changes: 18 additions & 2 deletions br/pkg/lightning/checkpoints/checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,15 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) {

switch cfg.Checkpoint.Driver {
case config.CheckpointDriverMySQL:
db, err := common.ConnectMySQL(cfg.Checkpoint.DSN)
var (
db *sql.DB
err error
)
if cfg.Checkpoint.MySQLParam != nil {
db, err = cfg.Checkpoint.MySQLParam.Connect()
} else {
db, err = sql.Open("mysql", cfg.Checkpoint.DSN)
}
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -546,7 +554,15 @@ func IsCheckpointsDBExists(ctx context.Context, cfg *config.Config) (bool, error
}
switch cfg.Checkpoint.Driver {
case config.CheckpointDriverMySQL:
db, err := sql.Open("mysql", cfg.Checkpoint.DSN)
var (
db *sql.DB
err error
)
if cfg.Checkpoint.MySQLParam != nil {
db, err = cfg.Checkpoint.MySQLParam.Connect()
} else {
db, err = sql.Open("mysql", cfg.Checkpoint.DSN)
}
if err != nil {
return false, errors.Trace(err)
}
Expand Down
51 changes: 28 additions & 23 deletions br/pkg/lightning/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -58,28 +57,38 @@ type MySQLConnectParam struct {
Vars map[string]string
}

func (param *MySQLConnectParam) ToDSN() string {
hostPort := net.JoinHostPort(param.Host, strconv.Itoa(param.Port))
dsn := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s",
param.User, param.Password, hostPort,
param.SQLMode, param.MaxAllowedPacket, param.TLS)
func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config {
cfg := mysql.NewConfig()
cfg.Params = make(map[string]string)

cfg.User = param.User
cfg.Passwd = param.Password
cfg.Net = "tcp"
cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port))
cfg.Params["charset"] = "utf8mb4"
cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode)
cfg.MaxAllowedPacket = int(param.MaxAllowedPacket)
cfg.TLSConfig = param.TLS

for k, v := range param.Vars {
dsn += fmt.Sprintf("&%s='%s'", k, url.QueryEscape(v))
cfg.Params[k] = fmt.Sprintf("'%s'", v)
}

return dsn
return cfg
}

func tryConnectMySQL(dsn string) (*sql.DB, error) {
driverName := "mysql"
failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) {
driverName = val.(string)
func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) {
pwd := val.(string)
if cfg.Passwd != pwd {
failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"})
}
failpoint.Return(nil, nil)
})
db, err := sql.Open(driverName, dsn)
c, err := mysql.NewConnector(cfg)
if err != nil {
return nil, errors.Trace(err)
}
db := sql.OpenDB(c)
if err = db.Ping(); err != nil {
_ = db.Close()
return nil, errors.Trace(err)
Expand All @@ -89,13 +98,9 @@ func tryConnectMySQL(dsn string) (*sql.DB, error) {

// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding,
// we will try to connect MySQL with the base64 decoding of the password.
func ConnectMySQL(dsn string) (*sql.DB, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, errors.Trace(err)
}
func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
// Try plain password first.
db, firstErr := tryConnectMySQL(dsn)
db, firstErr := tryConnectMySQL(cfg)
if firstErr == nil {
return db, nil
}
Expand All @@ -104,9 +109,9 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
// If password is encoded by base64, try the decoded string as well.
if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd {
cfg.Passwd = string(password)
db, err = tryConnectMySQL(cfg.FormatDSN())
db2, err := tryConnectMySQL(cfg)
if err == nil {
return db, nil
return db2, nil
}
}
}
Expand All @@ -115,7 +120,7 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
}

func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
db, err := ConnectMySQL(param.ToDSN())
db, err := ConnectMySQL(param.ToDriverConfig())
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
69 changes: 5 additions & 64 deletions br/pkg/lightning/common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@ package common_test

import (
"context"
"database/sql"
"database/sql/driver"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"

Expand All @@ -35,7 +31,6 @@ import (
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/log"
tmysql "github.com/pingcap/tidb/errno"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -85,66 +80,14 @@ func TestGetJSON(t *testing.T) {
require.Regexp(t, ".*http status code != 200.*", err.Error())
}

func TestToDSN(t *testing.T) {
param := common.MySQLConnectParam{
Host: "127.0.0.1",
Port: 4000,
User: "root",
Password: "123456",
SQLMode: "strict",
MaxAllowedPacket: 1234,
TLS: "cluster",
Vars: map[string]string{
"tidb_distsql_scan_concurrency": "1",
},
}
require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN())

param.Host = "::1"
require.Equal(t, "root:123456@tcp([::1]:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN())
}

type mockDriver struct {
driver.Driver
plainPsw string
}

func (m *mockDriver) Open(dsn string) (driver.Conn, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, err
}
accessDenied := cfg.Passwd != m.plainPsw
return &mockConn{accessDenied: accessDenied}, nil
}

type mockConn struct {
driver.Conn
driver.Pinger
accessDenied bool
}

func (c *mockConn) Ping(ctx context.Context) error {
if c.accessDenied {
return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}
}
return nil
}

func (c *mockConn) Close() error {
return nil
}

func TestConnect(t *testing.T) {
plainPsw := "dQAUoDiyb1ucWZk7"
driverName := "mysql-mock-" + strconv.Itoa(rand.Int())
sql.Register(driverName, &mockDriver{plainPsw: plainPsw})

require.NoError(t, failpoint.Enable(
"github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver",
fmt.Sprintf("return(\"%s\")", driverName)))
"github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword",
fmt.Sprintf("return(\"%s\")", plainPsw)))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver"))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword"))
}()

param := common.MySQLConnectParam{
Expand All @@ -155,13 +98,11 @@ func TestConnect(t *testing.T) {
SQLMode: "strict",
MaxAllowedPacket: 1234,
}
db, err := param.Connect()
_, err := param.Connect()
require.NoError(t, err)
require.NoError(t, db.Close())
param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw))
db, err = param.Connect()
_, err = param.Connect()
require.NoError(t, err)
require.NoError(t, db.Close())
}

func TestIsContextCanceledError(t *testing.T) {
Expand Down
13 changes: 7 additions & 6 deletions br/pkg/lightning/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,12 @@ type TikvImporter struct {
}

type Checkpoint struct {
Schema string `toml:"schema" json:"schema"`
DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON.
Driver string `toml:"driver" json:"driver"`
Enable bool `toml:"enable" json:"enable"`
KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"`
Schema string `toml:"schema" json:"schema"`
DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON.
MySQLParam *common.MySQLConnectParam `toml:"-" json:"-"` // For some security reason, we use MySQLParam instead of DSN.
Driver string `toml:"driver" json:"driver"`
Enable bool `toml:"enable" json:"enable"`
KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"`
}

type Cron struct {
Expand Down Expand Up @@ -1142,7 +1143,7 @@ func (cfg *Config) AdjustCheckPoint() {
MaxAllowedPacket: defaultMaxAllowedPacket,
TLS: cfg.TiDB.TLS,
}
cfg.Checkpoint.DSN = param.ToDSN()
cfg.Checkpoint.MySQLParam = &param
case CheckpointDriverFile:
cfg.Checkpoint.DSN = "/tmp/" + cfg.Checkpoint.Schema + ".pb"
}
Expand Down
5 changes: 3 additions & 2 deletions br/pkg/lightning/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/BurntSushi/toml"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/config"
"github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -626,7 +625,9 @@ func TestLoadConfig(t *testing.T) {
taskCfg.TiDB.DistSQLScanConcurrency = 1
err = taskCfg.Adjust(context.Background())
require.NoError(t, err)
require.Equal(t, "guest:12345@tcp(172.16.30.11:4001)/?charset=utf8mb4&sql_mode='"+mysql.DefaultSQLMode+"'&maxAllowedPacket=67108864&tls=false", taskCfg.Checkpoint.DSN)
equivalentDSN := taskCfg.Checkpoint.MySQLParam.ToDriverConfig().FormatDSN()
expectedDSN := "guest:12345@tcp(172.16.30.11:4001)/?tls=false&maxAllowedPacket=67108864&charset=utf8mb4&sql_mode=%27ONLY_FULL_GROUP_BY%2CSTRICT_TRANS_TABLES%2CNO_ZERO_IN_DATE%2CNO_ZERO_DATE%2CERROR_FOR_DIVISION_BY_ZERO%2CNO_AUTO_CREATE_USER%2CNO_ENGINE_SUBSTITUTION%27"
require.Equal(t, expectedDSN, equivalentDSN)

result := taskCfg.String()
require.Regexp(t, `.*"pd-addr":"172.16.30.11:2379,172.16.30.12:2379".*`, result)
Expand Down
15 changes: 10 additions & 5 deletions cmd/importer/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"strconv"
"strings"

_ "github.com/go-sql-driver/mysql"
mysql2 "github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
"github.com/pingcap/log"
"github.com/pingcap/tidb/parser/mysql"
Expand Down Expand Up @@ -318,13 +318,18 @@ func execSQL(db *sql.DB, sql string) error {
}

func createDB(cfg DBConfig) (*sql.DB, error) {
dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name)
db, err := sql.Open("mysql", dbDSN)
driverCfg := mysql2.NewConfig()
driverCfg.User = cfg.User
driverCfg.Passwd = cfg.Password
driverCfg.Net = "tcp"
driverCfg.Addr = cfg.Host + ":" + strconv.Itoa(cfg.Port)
driverCfg.DBName = cfg.Name

c, err := mysql2.NewConnector(driverCfg)
if err != nil {
return nil, errors.Trace(err)
}

return db, nil
return sql.OpenDB(c), nil
}

func closeDB(db *sql.DB) error {
Expand Down
25 changes: 25 additions & 0 deletions dumpling/export/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,31 @@ func (conf *Config) GetDSN(db string) string {
return dsn
}

// GetDriverConfig returns the MySQL driver config from Config.
func (conf *Config) GetDriverConfig(db string) *mysql.Config {
driverCfg := mysql.NewConfig()
// maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection.
// https://github.com/go-sql-driver/mysql#maxallowedpacket
hostPort := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port))
driverCfg.User = conf.User
driverCfg.Passwd = conf.Password
driverCfg.Net = "tcp"
driverCfg.Addr = hostPort
driverCfg.DBName = db
driverCfg.Collation = "utf8mb4_general_ci"
driverCfg.ReadTimeout = conf.ReadTimeout
driverCfg.WriteTimeout = 30 * time.Second
driverCfg.InterpolateParams = true
driverCfg.MaxAllowedPacket = 0
if conf.Security.DriveTLSName != "" {
driverCfg.TLSConfig = conf.Security.DriveTLSName
}
if conf.AllowCleartextPasswords {
driverCfg.AllowCleartextPasswords = true
}
return driverCfg
}

func timestampDirName() string {
return fmt.Sprintf("./export-%s", time.Now().Format(time.RFC3339))
}
Expand Down
Loading

0 comments on commit d037637

Please # to comment.