diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index f4fed78a..e4df35e0 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -133,5 +133,9 @@ func (mc *mockCapture) Capture(packet []byte, startTime time.Time, connID uint64 mc.connID = connID } +func (mc *mockCapture) Progress() (float64, error) { + return 0, nil +} + func (mc *mockCapture) Close() { } diff --git a/pkg/sqlreplay/capture/capture.go b/pkg/sqlreplay/capture/capture.go index 060241dc..a012c2ce 100644 --- a/pkg/sqlreplay/capture/capture.go +++ b/pkg/sqlreplay/capture/capture.go @@ -33,6 +33,8 @@ type Capture interface { Stop(err error) // Capture captures traffic Capture(packet []byte, startTime time.Time, connID uint64) + // Progress returns the progress of the capture job + Progress() (float64, error) // Close closes the capture Close() } @@ -202,6 +204,15 @@ func (c *capture) Capture(packet []byte, startTime time.Time, connID uint64) { } } +func (c *capture) Progress() (float64, error) { + c.Lock() + defer c.Unlock() + if c.startTime.IsZero() || c.cfg.Duration == 0 { + return 0, c.err + } + return float64(time.Since(c.startTime)) / float64(c.cfg.Duration), c.err +} + // stopNoLock must be called after holding a lock. func (c *capture) stopNoLock(err error) { // already stopped diff --git a/pkg/sqlreplay/capture/capture_test.go b/pkg/sqlreplay/capture/capture_test.go index 3ba5fd00..8c7dc3dd 100644 --- a/pkg/sqlreplay/capture/capture_test.go +++ b/pkg/sqlreplay/capture/capture_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/lib/util/logger" "github.com/pingcap/tiproxy/lib/util/waitgroup" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" @@ -34,7 +35,11 @@ func TestStartAndStop(t *testing.T) { // start capture and the traffic should be outputted require.NoError(t, cpt.Start(cfg)) cpt.Capture(packet, time.Now(), 100) - cpt.Stop(nil) + _, err := cpt.Progress() + require.NoError(t, err) + cpt.Stop(errors.Errorf("mock error")) + _, err = cpt.Progress() + require.ErrorContains(t, err, "mock error") cpt.wg.Wait() data := writer.getData() require.Greater(t, len(data), 0) diff --git a/pkg/sqlreplay/manager/job.go b/pkg/sqlreplay/manager/job.go new file mode 100644 index 00000000..00c44288 --- /dev/null +++ b/pkg/sqlreplay/manager/job.go @@ -0,0 +1,72 @@ +// Copyright 2024 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package manager + +import ( + "encoding/json" + "time" + + "github.com/siddontang/go/hack" +) + +type jobType int + +const ( + Capture jobType = iota + Replay +) + +type Job interface { + Type() jobType + String() string + SetProgress(progress float64, err error) + IsRunning() bool +} + +type job struct { + StartTime time.Time + Duration time.Duration + Progress float64 + Error error +} + +func (job *job) IsRunning() bool { + return job.Error == nil +} + +// TODO: refine the output +func (job *job) String() string { + b, err := json.Marshal(job) + if err != nil { + return err.Error() + } + return hack.String(b) +} + +func (job *job) SetProgress(progress float64, err error) { + if progress > job.Progress { + job.Progress = progress + } + job.Error = err +} + +var _ Job = (*captureJob)(nil) + +type captureJob struct { + job +} + +func (job *captureJob) Type() jobType { + return Capture +} + +var _ Job = (*replayJob)(nil) + +type replayJob struct { + job +} + +func (job *replayJob) Type() jobType { + return Replay +} diff --git a/pkg/sqlreplay/manager/job_test.go b/pkg/sqlreplay/manager/job_test.go new file mode 100644 index 00000000..21fb0931 --- /dev/null +++ b/pkg/sqlreplay/manager/job_test.go @@ -0,0 +1,46 @@ +// Copyright 2024 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package manager + +import ( + "testing" + "time" + + "github.com/pingcap/tiproxy/lib/util/errors" + "github.com/stretchr/testify/require" +) + +func TestJobs(t *testing.T) { + tests := []struct { + job Job + tp jobType + }{ + { + &captureJob{ + job: job{ + StartTime: time.Now(), + Duration: 10 * time.Second, + }, + }, + Capture, + }, + { + &replayJob{ + job: job{ + StartTime: time.Now(), + Duration: 10 * time.Second, + }, + }, + Replay, + }, + } + + for i, test := range tests { + require.Equal(t, test.tp, test.job.Type(), "case %d", i) + require.True(t, test.job.IsRunning(), "case %d", i) + test.job.SetProgress(0.5, errors.New("stopped manually")) + require.False(t, test.job.IsRunning(), "case %d", i) + require.NotEmpty(t, test.job.String(), "case %d", i) + } +} diff --git a/pkg/sqlreplay/manager/manager.go b/pkg/sqlreplay/manager/manager.go new file mode 100644 index 00000000..7f9891e5 --- /dev/null +++ b/pkg/sqlreplay/manager/manager.go @@ -0,0 +1,155 @@ +// Copyright 2024 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package manager + +import ( + "crypto/tls" + "encoding/json" + "time" + + "github.com/pingcap/tiproxy/lib/config" + "github.com/pingcap/tiproxy/lib/util/errors" + "github.com/pingcap/tiproxy/pkg/proxy/backend" + "github.com/pingcap/tiproxy/pkg/sqlreplay/capture" + "github.com/pingcap/tiproxy/pkg/sqlreplay/replay" + "github.com/siddontang/go/hack" + "go.uber.org/zap" +) + +type CertManager interface { + SQLTLS() *tls.Config +} + +type JobManager interface { + StartCapture(capture.CaptureConfig) error + StartReplay(replay.ReplayConfig) error + GetCapture() capture.Capture + Stop() string + Jobs() string + Close() +} + +var _ JobManager = (*jobManager)(nil) + +type jobManager struct { + jobHistory []Job + capture capture.Capture + replay replay.Replay + hsHandler backend.HandshakeHandler + certManager CertManager + cfg *config.Config + lg *zap.Logger +} + +func NewJobManager(lg *zap.Logger, cfg *config.Config, certMgr CertManager, hsHandler backend.HandshakeHandler) *jobManager { + return &jobManager{ + lg: lg, + capture: capture.NewCapture(lg.Named("capture")), + replay: replay.NewReplay(lg.Named("replay")), + hsHandler: hsHandler, + cfg: cfg, + certManager: certMgr, + } +} + +func (jm *jobManager) runningJob() Job { + if len(jm.jobHistory) == 0 { + return nil + } + job := jm.jobHistory[len(jm.jobHistory)-1] + if job.IsRunning() { + switch job.Type() { + case Capture: + progress, err := jm.capture.Progress() + job.SetProgress(progress, err) + case Replay: + progress, err := jm.replay.Progress() + job.SetProgress(progress, err) + } + if job.IsRunning() { + return job + } + } + return nil +} + +func (jm *jobManager) StartCapture(cfg capture.CaptureConfig) error { + running := jm.runningJob() + if running != nil { + return errors.Errorf("a job is running: %s", running.String()) + } + newJob := &captureJob{ + job: job{ + StartTime: time.Now(), + Duration: cfg.Duration, + }, + } + err := jm.capture.Start(cfg) + if err != nil { + newJob.SetProgress(0, err) + } + jm.jobHistory = append(jm.jobHistory, newJob) + return errors.Wrapf(err, "start capture failed") +} + +func (jm *jobManager) StartReplay(cfg replay.ReplayConfig) error { + running := jm.runningJob() + if running != nil { + return errors.Errorf("a job is running: %s", running.String()) + } + newJob := &replayJob{ + job: job{ + StartTime: time.Now(), + }, + } + // TODO: support update configs online + err := jm.replay.Start(cfg, jm.certManager.SQLTLS(), jm.hsHandler, &backend.BCConfig{ + ProxyProtocol: jm.cfg.Proxy.ProxyProtocol != "", + RequireBackendTLS: jm.cfg.Security.RequireBackendTLS, + HealthyKeepAlive: jm.cfg.Proxy.BackendHealthyKeepalive, + UnhealthyKeepAlive: jm.cfg.Proxy.BackendUnhealthyKeepalive, + ConnBufferSize: jm.cfg.Proxy.ConnBufferSize, + }) + if err != nil { + newJob.SetProgress(0, err) + } + jm.jobHistory = append(jm.jobHistory, newJob) + return errors.Wrapf(err, "start replay failed") +} + +func (jm *jobManager) GetCapture() capture.Capture { + return jm.capture +} + +func (jm *jobManager) Jobs() string { + b, err := json.Marshal(jm.jobHistory) + if err != nil { + return err.Error() + } + return hack.String(b) +} + +func (jm *jobManager) Stop() string { + job := jm.runningJob() + if job == nil { + return "no job running" + } + switch job.Type() { + case Capture: + jm.capture.Stop(errors.Errorf("manually stopped")) + case Replay: + jm.replay.Stop(errors.Errorf("manually stopped")) + } + job.SetProgress(0, errors.Errorf("manually stopped")) + return "stopped: " + job.String() +} + +func (jm *jobManager) Close() { + if jm.capture != nil { + jm.capture.Close() + } + if jm.replay != nil { + jm.replay.Close() + } +} diff --git a/pkg/sqlreplay/manager/manager_test.go b/pkg/sqlreplay/manager/manager_test.go new file mode 100644 index 00000000..273da97e --- /dev/null +++ b/pkg/sqlreplay/manager/manager_test.go @@ -0,0 +1,41 @@ +// Copyright 2024 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package manager + +import ( + "testing" + + "github.com/pingcap/tiproxy/lib/config" + "github.com/pingcap/tiproxy/pkg/sqlreplay/capture" + "github.com/pingcap/tiproxy/pkg/sqlreplay/replay" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestStartAndStop(t *testing.T) { + mgr := NewJobManager(zap.NewNop(), &config.Config{}, &mockCertMgr{}, nil) + defer mgr.Close() + mgr.capture = &mockCapture{} + mgr.replay = &mockReplay{} + + require.Contains(t, mgr.Stop(), "no job running") + require.NotNil(t, mgr.GetCapture()) + + require.NoError(t, mgr.StartCapture(capture.CaptureConfig{})) + require.Error(t, mgr.StartCapture(capture.CaptureConfig{})) + require.Error(t, mgr.StartReplay(replay.ReplayConfig{})) + require.Len(t, mgr.jobHistory, 1) + require.NotEmpty(t, mgr.Jobs()) + require.Contains(t, mgr.Stop(), "stopped") + require.Contains(t, mgr.Stop(), "no job running") + require.Len(t, mgr.jobHistory, 1) + + require.NoError(t, mgr.StartReplay(replay.ReplayConfig{})) + require.Error(t, mgr.StartCapture(capture.CaptureConfig{})) + require.Error(t, mgr.StartReplay(replay.ReplayConfig{})) + require.Len(t, mgr.jobHistory, 2) + require.Contains(t, mgr.Stop(), "stopped") + require.Contains(t, mgr.Stop(), "no job running") + require.Len(t, mgr.jobHistory, 2) +} diff --git a/pkg/sqlreplay/manager/mock_test.go b/pkg/sqlreplay/manager/mock_test.go new file mode 100644 index 00000000..bbaf149f --- /dev/null +++ b/pkg/sqlreplay/manager/mock_test.go @@ -0,0 +1,63 @@ +// Copyright 2024 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package manager + +import ( + "crypto/tls" + "time" + + "github.com/pingcap/tiproxy/pkg/proxy/backend" + "github.com/pingcap/tiproxy/pkg/sqlreplay/capture" + "github.com/pingcap/tiproxy/pkg/sqlreplay/replay" +) + +var _ CertManager = (*mockCertMgr)(nil) + +type mockCertMgr struct { +} + +func (mockCertMgr) SQLTLS() *tls.Config { + return nil +} + +var _ capture.Capture = (*mockCapture)(nil) + +type mockCapture struct { +} + +func (m *mockCapture) Capture(packet []byte, startTime time.Time, connID uint64) { +} + +func (m *mockCapture) Close() { +} + +func (m *mockCapture) Progress() (float64, error) { + return 0, nil +} + +func (m *mockCapture) Stop(err error) { +} + +func (mockCapture) Start(capture.CaptureConfig) error { + return nil +} + +var _ replay.Replay = (*mockReplay)(nil) + +type mockReplay struct { +} + +func (m *mockReplay) Close() { +} + +func (m *mockReplay) Progress() (float64, error) { + return 0, nil +} + +func (m *mockReplay) Start(cfg replay.ReplayConfig, backendTLSConfig *tls.Config, hsHandler backend.HandshakeHandler, bcConfig *backend.BCConfig) error { + return nil +} + +func (m *mockReplay) Stop(err error) { +} diff --git a/pkg/sqlreplay/replay/replay.go b/pkg/sqlreplay/replay/replay.go index 01f3c66d..97bfa792 100644 --- a/pkg/sqlreplay/replay/replay.go +++ b/pkg/sqlreplay/replay/replay.go @@ -31,9 +31,11 @@ const ( type Replay interface { // Start starts the replay - Start(cfg ReplayConfig) error + Start(cfg ReplayConfig, backendTLSConfig *tls.Config, hsHandler backend.HandshakeHandler, bcConfig *backend.BCConfig) error // Stop stops the replay Stop(err error) + // Progress returns the progress of the replay job + Progress() (float64, error) // Close closes the replay Close() } @@ -43,7 +45,10 @@ type ReplayConfig struct { Username string Password string Speed float64 - reader cmd.LineReader + // the following fields are for testing + reader cmd.LineReader + report report.Report + connCreator conn.ConnCreator } func (cfg *ReplayConfig) Validate() error { @@ -85,25 +90,13 @@ type replay struct { connCount int } -func NewReplay(lg *zap.Logger, backendTLSConfig *tls.Config, hsHandler backend.HandshakeHandler, bcConfig *backend.BCConfig) *replay { - r := &replay{ - conns: make(map[uint64]conn.Conn), - lg: lg, - exceptionCh: make(chan conn.Exception, maxPendingExceptions), - closeCh: make(chan uint64, maxPendingExceptions), +func NewReplay(lg *zap.Logger) *replay { + return &replay{ + lg: lg, } - r.connCreator = func(connID uint64) conn.Conn { - return conn.NewConn(lg, r.cfg.Username, r.cfg.Password, backendTLSConfig, hsHandler, connID, bcConfig, r.exceptionCh, r.closeCh) - } - backendConnCreator := func() conn.BackendConn { - // TODO: allocate connection ID. - return conn.NewBackendConn(lg.Named("be"), 1, hsHandler, bcConfig, backendTLSConfig, r.cfg.Username, r.cfg.Password) - } - r.report = report.NewReport(lg.Named("report"), r.exceptionCh, backendConnCreator) - return r } -func (r *replay) Start(cfg ReplayConfig) error { +func (r *replay) Start(cfg ReplayConfig, backendTLSConfig *tls.Config, hsHandler backend.HandshakeHandler, bcConfig *backend.BCConfig) error { if err := cfg.Validate(); err != nil { return err } @@ -111,6 +104,24 @@ func (r *replay) Start(cfg ReplayConfig) error { r.Lock() defer r.Unlock() r.cfg = cfg + r.conns = make(map[uint64]conn.Conn) + r.exceptionCh = make(chan conn.Exception, maxPendingExceptions) + r.closeCh = make(chan uint64, maxPendingExceptions) + r.connCreator = cfg.connCreator + if r.connCreator == nil { + r.connCreator = func(connID uint64) conn.Conn { + return conn.NewConn(r.lg.Named("conn"), r.cfg.Username, r.cfg.Password, backendTLSConfig, hsHandler, connID, bcConfig, r.exceptionCh, r.closeCh) + } + } + r.report = cfg.report + if r.report == nil { + backendConnCreator := func() conn.BackendConn { + // TODO: allocate connection ID. + return conn.NewBackendConn(r.lg.Named("be"), 1, hsHandler, bcConfig, backendTLSConfig, r.cfg.Username, r.cfg.Password) + } + r.report = report.NewReport(r.lg.Named("report"), r.exceptionCh, backendConnCreator) + } + childCtx, cancel := context.WithCancel(context.Background()) r.cancel = cancel if err := r.report.Start(childCtx, report.ReportConfig{ @@ -225,6 +236,10 @@ func (r *replay) readCloseCh(ctx context.Context) { } } +func (r *replay) Progress() (float64, error) { + return 0, r.err +} + func (r *replay) Stop(err error) { r.Lock() defer r.Unlock() diff --git a/pkg/sqlreplay/replay/replay_test.go b/pkg/sqlreplay/replay/replay_test.go index b8ddfbb8..c3ce2fd4 100644 --- a/pkg/sqlreplay/replay/replay_test.go +++ b/pkg/sqlreplay/replay/replay_test.go @@ -18,22 +18,24 @@ import ( func TestManageConns(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) - replay := NewReplay(lg, nil, nil, &backend.BCConfig{}) + replay := NewReplay(lg) defer replay.Close() - replay.connCreator = func(connID uint64) conn.Conn { - return &mockConn{ - connID: connID, - exceptionCh: replay.exceptionCh, - closeCh: replay.closeCh, - } - } - replay.report = newMockReport(replay.exceptionCh) + loader := newMockChLoader() - require.NoError(t, replay.Start(ReplayConfig{ + cfg := ReplayConfig{ Input: t.TempDir(), Username: "u1", reader: loader, - })) + connCreator: func(connID uint64) conn.Conn { + return &mockConn{ + connID: connID, + exceptionCh: replay.exceptionCh, + closeCh: replay.closeCh, + } + }, + report: newMockReport(replay.exceptionCh), + } + require.NoError(t, replay.Start(cfg, nil, nil, &backend.BCConfig{})) command := newMockCommand(1) loader.writeCommand(command) @@ -99,32 +101,34 @@ func TestReplaySpeed(t *testing.T) { speeds := []float64{10, 1, 0.1} var lastTotalTime time.Duration + replay := NewReplay(lg) + defer replay.Close() for _, speed := range speeds { - replay := NewReplay(lg, nil, nil, &backend.BCConfig{}) cmdCh := make(chan *cmd.Command, 10) - replay.connCreator = func(connID uint64) conn.Conn { - return &mockConn{ - connID: connID, - cmdCh: cmdCh, - exceptionCh: replay.exceptionCh, - closeCh: replay.closeCh, - } + loader := newMockNormalLoader() + cfg := ReplayConfig{ + Input: t.TempDir(), + Username: "u1", + Speed: speed, + reader: loader, + report: newMockReport(replay.exceptionCh), + connCreator: func(connID uint64) conn.Conn { + return &mockConn{ + connID: connID, + cmdCh: cmdCh, + exceptionCh: replay.exceptionCh, + closeCh: replay.closeCh, + } + }, } - replay.report = newMockReport(replay.exceptionCh) - loader := newMockNormalLoader() now := time.Now() for i := 0; i < 10; i++ { command := newMockCommand(1) command.StartTs = now.Add(time.Duration(i*10) * time.Millisecond) loader.writeCommand(command) } - require.NoError(t, replay.Start(ReplayConfig{ - Input: t.TempDir(), - Username: "u1", - reader: loader, - Speed: speed, - })) + require.NoError(t, replay.Start(cfg, nil, nil, &backend.BCConfig{})) var firstTime, lastTime time.Time for i := 0; i < 10; i++ { @@ -137,12 +141,15 @@ func TestReplaySpeed(t *testing.T) { interval := now.Sub(lastTime) lastTime = now t.Logf("speed: %f, i: %d, interval: %s", speed, i, interval) + // CI is too unstable, comment this. // require.Greater(t, interval, time.Duration(float64(10*time.Millisecond)/speed)/2, "speed: %f, i: %d", speed, i) } } totalTime := lastTime.Sub(firstTime) require.Greater(t, totalTime, lastTotalTime, "speed: %f", speed) lastTotalTime = totalTime - replay.Close() + + replay.Stop(nil) + loader.Close() } }