diff --git a/common/persistence/dataInterfaces.go b/common/persistence/dataInterfaces.go index b10ac105143..30514b65c86 100644 --- a/common/persistence/dataInterfaces.go +++ b/common/persistence/dataInterfaces.go @@ -1049,13 +1049,13 @@ type ( // AppendHistoryNodes add a node to history node table AppendHistoryNodes(ctx context.Context, request *AppendHistoryNodesRequest) (*AppendHistoryNodesResponse, error) - // AppendRawHistoryNodes add a node of raw histories to history ndoe table + // AppendRawHistoryNodes add a node of raw histories to history node table AppendRawHistoryNodes(ctx context.Context, request *AppendRawHistoryNodesRequest) (*AppendHistoryNodesResponse, error) // ReadHistoryBranch returns history node data for a branch ReadHistoryBranch(ctx context.Context, request *ReadHistoryBranchRequest) (*ReadHistoryBranchResponse, error) // ReadHistoryBranchByBatch returns history node data for a branch ByBatch ReadHistoryBranchByBatch(ctx context.Context, request *ReadHistoryBranchRequest) (*ReadHistoryBranchByBatchResponse, error) - // ReadHistoryBranch returns history node data for a branch + // ReadHistoryBranchReverse returns history node data for a branch ReadHistoryBranchReverse(ctx context.Context, request *ReadHistoryBranchReverseRequest) (*ReadHistoryBranchReverseResponse, error) // ReadRawHistoryBranch returns history node raw data for a branch ByBatch // NOTE: this API should only be used by 3+DC diff --git a/host/xdc/integration_failover_test.go b/host/xdc/integration_failover_test.go index d3769401a22..ca8813b0cb8 100644 --- a/host/xdc/integration_failover_test.go +++ b/host/xdc/integration_failover_test.go @@ -2419,6 +2419,110 @@ func (s *integrationClustersTestSuite) TestForceMigration_ClosedWorkflow() { s.Equal(enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED, descResp.GetWorkflowExecutionInfo().Status) } +func (s *integrationClustersTestSuite) TestForceMigration_ResetWorkflow() { + testCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + namespace := "force-replication" + common.GenerateRandomString(5) + s.registerNamespace(namespace, true) + + taskqueue := "integration-force-replication-reset-task-queue" + client1, worker1 := s.newClientAndWorker(s.cluster1.GetHost().FrontendGRPCAddress(), namespace, taskqueue, "worker1") + + testWorkflowFn := func(ctx workflow.Context) error { + return nil + } + + worker1.RegisterWorkflow(testWorkflowFn) + worker1.Start() + + // Start wf1 + workflowID := "force-replication-test-reset-wf-1" + run1, err := client1.ExecuteWorkflow(testCtx, sdkclient.StartWorkflowOptions{ + ID: workflowID, + TaskQueue: taskqueue, + WorkflowRunTimeout: time.Second * 30, + }, testWorkflowFn) + + s.NoError(err) + s.NotEmpty(run1.GetRunID()) + s.logger.Info("start wf1", tag.WorkflowRunID(run1.GetRunID())) + // wait until wf1 complete + err = run1.Get(testCtx, nil) + s.NoError(err) + + resp, err := client1.ResetWorkflowExecution(testCtx, &workflowservice.ResetWorkflowExecutionRequest{ + Namespace: namespace, + WorkflowExecution: &commonpb.WorkflowExecution{ + WorkflowId: workflowID, + RunId: run1.GetRunID(), + }, + Reason: "test", + WorkflowTaskFinishEventId: 3, + RequestId: uuid.New(), + }) + s.NoError(err) + resetRun := client1.GetWorkflow(testCtx, workflowID, resp.GetRunId()) + err = resetRun.Get(testCtx, nil) + s.NoError(err) + + frontendClient1 := s.cluster1.GetFrontendClient() + // Update ns to have 2 clusters + _, err = frontendClient1.UpdateNamespace(testCtx, &workflowservice.UpdateNamespaceRequest{ + Namespace: namespace, + ReplicationConfig: &replicationpb.NamespaceReplicationConfig{ + Clusters: clusterReplicationConfig, + }, + }) + s.NoError(err) + + // Wait for ns cache to pick up the change + time.Sleep(cacheRefreshInterval) + + nsResp, err := frontendClient1.DescribeNamespace(testCtx, &workflowservice.DescribeNamespaceRequest{ + Namespace: namespace, + }) + s.NoError(err) + s.True(nsResp.IsGlobalNamespace) + s.Equal(2, len(nsResp.ReplicationConfig.Clusters)) + + // Start force-replicate wf + sysClient, err := sdkclient.Dial(sdkclient.Options{ + HostPort: s.cluster1.GetHost().FrontendGRPCAddress(), + Namespace: "temporal-system", + }) + forceReplicationWorkflowID := "force-replication-wf" + sysWfRun, err := sysClient.ExecuteWorkflow(testCtx, sdkclient.StartWorkflowOptions{ + ID: forceReplicationWorkflowID, + TaskQueue: sw.DefaultWorkerTaskQueue, + WorkflowRunTimeout: time.Second * 30, + }, "force-replication", migration.ForceReplicationParams{ + Namespace: namespace, + OverallRps: 10, + }) + s.NoError(err) + err = sysWfRun.Get(testCtx, nil) + s.NoError(err) + + // Verify all wf in ns is now available in cluster2 + client2, _ := s.newClientAndWorker(s.cluster2.GetHost().FrontendGRPCAddress(), namespace, taskqueue, "worker2") + verifyHistory := func(wfID string, runID string) { + iter1 := client1.GetWorkflowHistory(testCtx, wfID, runID, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + iter2 := client2.GetWorkflowHistory(testCtx, wfID, runID, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + for iter1.HasNext() && iter2.HasNext() { + event1, err := iter1.Next() + s.NoError(err) + event2, err := iter2.Next() + s.NoError(err) + s.Equal(event1, event2) + } + s.False(iter1.HasNext()) + s.False(iter2.HasNext()) + } + verifyHistory(workflowID, run1.GetRunID()) + verifyHistory(workflowID, resp.GetRunId()) +} + func (s *integrationClustersTestSuite) getHistory(client host.FrontendClient, namespace string, execution *commonpb.WorkflowExecution) []*historypb.HistoryEvent { historyResponse, err := client.GetWorkflowExecutionHistory(host.NewContext(), &workflowservice.GetWorkflowExecutionHistoryRequest{ Namespace: namespace, diff --git a/service/history/nDCHistoryReplicator.go b/service/history/nDCHistoryReplicator.go index f7a3400305d..0429761ed71 100644 --- a/service/history/nDCHistoryReplicator.go +++ b/service/history/nDCHistoryReplicator.go @@ -26,16 +26,20 @@ package history import ( "context" + "fmt" + "sort" "time" "github.com/pborman/uuid" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" + historypb "go.temporal.io/api/history/v1" "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/adminservice/v1" enumsspb "go.temporal.io/server/api/enums/v1" "go.temporal.io/server/api/historyservice/v1" + persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/collection" @@ -831,7 +835,28 @@ func (r *nDCHistoryReplicatorImpl) backfillHistory( lastEventVersion int64, branchToken []byte, ) (*time.Time, error) { - historyIterator := collection.NewPagingIterator(r.getHistoryPaginationFn( + + // Get the last batch node id to check if the history data is already in DB. + localHistoryIterator := collection.NewPagingIterator(r.getHistoryFromLocalPaginationFn( + ctx, + branchToken, + lastEventID, + )) + var lastBatchNodeID int64 + for localHistoryIterator.HasNext() { + localHistoryBatch, err := localHistoryIterator.Next() + switch err.(type) { + case nil: + if len(localHistoryBatch.GetEvents()) > 0 { + lastBatchNodeID = localHistoryBatch.GetEvents()[0].GetEventId() + } + case *serviceerror.NotFound: + default: + return nil, err + } + } + + remoteHistoryIterator := collection.NewPagingIterator(r.getHistoryFromRemotePaginationFn( ctx, remoteClusterName, namespaceName, @@ -839,34 +864,88 @@ func (r *nDCHistoryReplicatorImpl) backfillHistory( workflowID, runID, lastEventID, - lastEventVersion)) - + lastEventVersion), + ) var lastHistoryBatch *commonpb.DataBlob prevTxnID := common.EmptyVersion - for historyIterator.HasNext() { - historyBlob, err := historyIterator.Next() + historyBranch, err := serialization.HistoryBranchFromBlob(branchToken, enumspb.ENCODING_TYPE_PROTO3.String()) + if err != nil { + return nil, err + } + latestBranchID := historyBranch.GetBranchId() + var prevBranchID string + + sortedAncestors := copyAndSortAncestors(historyBranch.GetAncestors()) + sortedAncestorsIdx := 0 + historyBranch.Ancestors = nil + +BackfillLoop: + for remoteHistoryIterator.HasNext() { + historyBlob, err := remoteHistoryIterator.Next() + if err != nil { + return nil, err + } + + if historyBlob.nodeID <= lastBatchNodeID { + // The history batch already in DB. + continue BackfillLoop + } + + if sortedAncestorsIdx < len(sortedAncestors) { + currentAncestor := sortedAncestors[sortedAncestorsIdx] + if historyBlob.nodeID >= currentAncestor.GetEndNodeId() { + // update ancestor + historyBranch.Ancestors = append(historyBranch.Ancestors, currentAncestor) + sortedAncestorsIdx++ + } + if sortedAncestorsIdx < len(sortedAncestors) { + // use ancestor branch id + currentAncestor = sortedAncestors[sortedAncestorsIdx] + historyBranch.BranchId = currentAncestor.GetBranchId() + if historyBlob.nodeID < currentAncestor.GetBeginNodeId() || historyBlob.nodeID >= currentAncestor.GetEndNodeId() { + return nil, serviceerror.NewInternal( + fmt.Sprintf("The backfill history blob node id %d is not in acestoer range [%d, %d]", + historyBlob.nodeID, + currentAncestor.GetBeginNodeId(), + currentAncestor.GetEndNodeId()), + ) + } + } else { + // no more ancestor, use the latest branch ID + historyBranch.BranchId = latestBranchID + } + } + + filteredHistoryBranch, err := serialization.HistoryBranchToBlob(historyBranch) if err != nil { return nil, err } - lastHistoryBatch = historyBlob.rawHistory txnID, err := r.shard.GenerateTaskID() if err != nil { return nil, err } _, err = r.shard.GetExecutionManager().AppendRawHistoryNodes(ctx, &persistence.AppendRawHistoryNodesRequest{ ShardID: r.shard.GetShardID(), - IsNewBranch: prevTxnID == common.EmptyVersion, - BranchToken: branchToken, + IsNewBranch: prevBranchID != historyBranch.BranchId, + BranchToken: filteredHistoryBranch.GetData(), History: historyBlob.rawHistory, PrevTransactionID: prevTxnID, TransactionID: txnID, NodeID: historyBlob.nodeID, + Info: persistence.BuildHistoryGarbageCleanupInfo( + namespaceID.String(), + workflowID, + runID, + ), }) if err != nil { return nil, err } prevTxnID = txnID + prevBranchID = historyBranch.BranchId + lastHistoryBatch = historyBlob.rawHistory } + var lastEventTime *time.Time events, _ := r.historySerializer.DeserializeEvents(lastHistoryBatch) if len(events) > 0 { @@ -875,7 +954,21 @@ func (r *nDCHistoryReplicatorImpl) backfillHistory( return lastEventTime, nil } -func (r *nDCHistoryReplicatorImpl) getHistoryPaginationFn( +func copyAndSortAncestors(input []*persistencespb.HistoryBranchRange) []*persistencespb.HistoryBranchRange { + ans := make([]*persistencespb.HistoryBranchRange, len(input)) + copy(ans, input) + if len(ans) > 0 { + // sort ans based onf EndNodeID so that we can set BeginNodeID + sort.Slice(ans, func(i, j int) bool { return ans[i].GetEndNodeId() < ans[j].GetEndNodeId() }) + ans[0].BeginNodeId = int64(1) + for i := 1; i < len(ans); i++ { + ans[i].BeginNodeId = ans[i-1].GetEndNodeId() + } + } + return ans +} + +func (r *nDCHistoryReplicatorImpl) getHistoryFromRemotePaginationFn( ctx context.Context, remoteClusterName string, namespaceName namespace.Name, @@ -915,3 +1008,30 @@ func (r *nDCHistoryReplicatorImpl) getHistoryPaginationFn( return batches, response.NextPageToken, nil } } + +func (r *nDCHistoryReplicatorImpl) getHistoryFromLocalPaginationFn( + ctx context.Context, + branchToken []byte, + lastEventID int64, +) collection.PaginationFn[*historypb.History] { + + return func(paginationToken []byte) ([]*historypb.History, []byte, error) { + response, err := r.shard.GetExecutionManager().ReadHistoryBranchByBatch(ctx, &persistence.ReadHistoryBranchRequest{ + ShardID: r.shard.GetShardID(), + BranchToken: branchToken, + MinEventID: common.FirstEventID, + MaxEventID: lastEventID + 1, + PageSize: 100, + NextPageToken: paginationToken, + }) + if err != nil { + return nil, nil, err + } + + histories := make([]*historypb.History, 0, len(response.History)) + for _, history := range response.History { + histories = append(histories, history) + } + return histories, response.NextPageToken, nil + } +} diff --git a/service/history/nDCHistoryReplicator_test.go b/service/history/nDCHistoryReplicator_test.go index a8d315f143d..3d4ac1d3339 100644 --- a/service/history/nDCHistoryReplicator_test.go +++ b/service/history/nDCHistoryReplicator_test.go @@ -128,6 +128,12 @@ func (s *nDCHistoryReplicatorSuite) TearDownTest() { func (s *nDCHistoryReplicatorSuite) Test_ApplyWorkflowState_BrandNew() { namespaceID := uuid.New() namespaceName := "namespaceName" + historyBranch, err := serialization.HistoryBranchToBlob(&persistencespb.HistoryBranch{ + TreeId: uuid.New(), + BranchId: uuid.New(), + Ancestors: nil, + }) + s.NoError(err) request := &historyservice.ReplicateWorkflowStateRequest{ WorkflowState: &persistencespb.WorkflowMutableState{ ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ @@ -137,7 +143,7 @@ func (s *nDCHistoryReplicatorSuite) Test_ApplyWorkflowState_BrandNew() { CurrentVersionHistoryIndex: 0, Histories: []*historyspb.VersionHistory{ { - BranchToken: []byte{123}, + BranchToken: historyBranch.GetData(), Items: []*historyspb.VersionHistoryItem{ { EventId: int64(100), @@ -188,14 +194,169 @@ func (s *nDCHistoryReplicatorSuite) Test_ApplyWorkflowState_BrandNew() { &adminservice.GetWorkflowExecutionRawHistoryV2Response{}, nil, ) + s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("test")) + s.mockExecutionManager.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + fakeStartHistory := &historypb.HistoryEvent{ + Attributes: &historypb.HistoryEvent_WorkflowExecutionStartedEventAttributes{ + WorkflowExecutionStartedEventAttributes: &historypb.WorkflowExecutionStartedEventAttributes{}, + }, + } + s.mockEventCache.EXPECT().GetEvent(gomock.Any(), gomock.Any(), common.FirstEventID, gomock.Any()).Return(fakeStartHistory, nil).AnyTimes() + err = s.historyReplicator.ApplyWorkflowState(context.Background(), request) + s.NoError(err) +} + +func (s *nDCHistoryReplicatorSuite) Test_ApplyWorkflowState_Ancestors() { + namespaceID := uuid.New() + namespaceName := "namespaceName" + historyBranch, err := serialization.HistoryBranchToBlob(&persistencespb.HistoryBranch{ + TreeId: uuid.New(), + BranchId: uuid.New(), + Ancestors: []*persistencespb.HistoryBranchRange{ + { + BranchId: uuid.New(), + BeginNodeId: 1, + EndNodeId: 3, + }, + { + BranchId: uuid.New(), + BeginNodeId: 3, + EndNodeId: 4, + }, + }, + }) + s.NoError(err) + request := &historyservice.ReplicateWorkflowStateRequest{ + WorkflowState: &persistencespb.WorkflowMutableState{ + ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ + WorkflowId: s.workflowID, + NamespaceId: namespaceID, + VersionHistories: &historyspb.VersionHistories{ + CurrentVersionHistoryIndex: 0, + Histories: []*historyspb.VersionHistory{ + { + BranchToken: historyBranch.GetData(), + Items: []*historyspb.VersionHistoryItem{ + { + EventId: int64(100), + Version: int64(100), + }, + }, + }, + }, + }, + }, + ExecutionState: &persistencespb.WorkflowExecutionState{ + RunId: s.runID, + State: enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, + Status: enumspb.WORKFLOW_EXECUTION_STATUS_TERMINATED, + }, + }, + RemoteCluster: "test", + } + we := commonpb.WorkflowExecution{ + WorkflowId: s.workflowID, + RunId: s.runID, + } + mockWeCtx := workflow.NewMockContext(s.controller) + s.mockHistoryCache.EXPECT().GetOrCreateWorkflowExecution( + gomock.Any(), + namespace.ID(namespaceID), + we, + workflow.CallerTypeTask, + ).Return(mockWeCtx, workflow.NoopReleaseFn, nil) + mockWeCtx.EXPECT().CreateWorkflowExecution( + gomock.Any(), + gomock.Any(), + persistence.CreateWorkflowModeBrandNew, + "", + int64(0), + gomock.Any(), + gomock.Any(), + []*persistence.WorkflowEvents{}, + ).Return(nil) + s.mockNamespaceCache.EXPECT().GetNamespaceByID(namespace.ID(namespaceID)).Return(namespace.NewNamespaceForTest( + &persistencespb.NamespaceInfo{Name: namespaceName}, + nil, + false, + nil, + int64(100), + ), nil).AnyTimes() + expectedHistory := []*historypb.History{ + { + Events: []*historypb.HistoryEvent{ + { + EventId: 1, + }, + { + EventId: 2, + }, + }, + }, + { + Events: []*historypb.HistoryEvent{ + { + EventId: 3, + }, + }, + }, + { + Events: []*historypb.HistoryEvent{ + { + EventId: 4, + }, + }, + }, + { + Events: []*historypb.HistoryEvent{ + { + EventId: 5, + }, + { + EventId: 6, + }, + }, + }, + } + serializer := serialization.NewSerializer() + var historyBlobs []*commonpb.DataBlob + var nodeIds []int64 + for _, history := range expectedHistory { + blob, err := serializer.SerializeEvents(history.GetEvents(), enumspb.ENCODING_TYPE_PROTO3) + s.NoError(err) + historyBlobs = append(historyBlobs, blob) + nodeIds = append(nodeIds, history.GetEvents()[0].GetEventId()) + } + s.mockRemoteAdminClient.EXPECT().GetWorkflowExecutionRawHistoryV2(gomock.Any(), gomock.Any()).Return( + &adminservice.GetWorkflowExecutionRawHistoryV2Response{ + HistoryBatches: historyBlobs, + HistoryNodeIds: nodeIds, + }, + nil, + ) + s.mockExecutionManager.EXPECT().ReadHistoryBranchByBatch(gomock.Any(), gomock.Any()).Return(&persistence.ReadHistoryBranchByBatchResponse{ + History: []*historypb.History{ + { + Events: []*historypb.HistoryEvent{ + { + EventId: 1, + }, + { + EventId: 2, + }, + }, + }, + }, + }, nil) s.mockExecutionManager.EXPECT().GetCurrentExecution(gomock.Any(), gomock.Any()).Return(nil, serviceerror.NewNotFound("")) + s.mockExecutionManager.EXPECT().AppendRawHistoryNodes(gomock.Any(), gomock.Any()).Return(nil, nil).Times(3) fakeStartHistory := &historypb.HistoryEvent{ Attributes: &historypb.HistoryEvent_WorkflowExecutionStartedEventAttributes{ WorkflowExecutionStartedEventAttributes: &historypb.WorkflowExecutionStartedEventAttributes{}, }, } s.mockEventCache.EXPECT().GetEvent(gomock.Any(), gomock.Any(), common.FirstEventID, gomock.Any()).Return(fakeStartHistory, nil).AnyTimes() - err := s.historyReplicator.ApplyWorkflowState(context.Background(), request) + err = s.historyReplicator.ApplyWorkflowState(context.Background(), request) s.NoError(err) } diff --git a/service/history/replication/ack_manager.go b/service/history/replication/ack_manager.go index 5daf22ca283..0ad4966173d 100644 --- a/service/history/replication/ack_manager.go +++ b/service/history/replication/ack_manager.go @@ -182,6 +182,17 @@ func (p *ackMgrImpl) GetTask( FirstEventID: taskInfo.FirstEventId, NextEventID: taskInfo.NextEventId, }) + case enumsspb.TASK_TYPE_REPLICATION_SYNC_WORKFLOW_STATE: + return p.taskInfoToTask(ctx, &tasks.SyncWorkflowStateTask{ + WorkflowKey: definition.NewWorkflowKey( + taskInfo.GetNamespaceId(), + taskInfo.GetWorkflowId(), + taskInfo.GetRunId(), + ), + VisibilityTimestamp: time.Unix(0, 0), + TaskID: taskInfo.TaskId, + Version: taskInfo.Version, + }) default: return nil, serviceerror.NewInternal(fmt.Sprintf("Unknown replication task type: %v", taskInfo.TaskType)) } diff --git a/service/history/replication/task_processor.go b/service/history/replication/task_processor.go index e720d609838..c9727a25684 100644 --- a/service/history/replication/task_processor.go +++ b/service/history/replication/task_processor.go @@ -47,6 +47,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" + "go.temporal.io/server/common/persistence/versionhistory" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/quotas" serviceerrors "go.temporal.io/server/common/serviceerror" @@ -427,6 +428,32 @@ func (p *taskProcessorImpl) convertTaskToDLQTask( }, }, nil + case enumsspb.REPLICATION_TASK_TYPE_SYNC_WORKFLOW_STATE_TASK: + taskAttributes := replicationTask.GetSyncWorkflowStateTaskAttributes() + executionInfo := taskAttributes.GetWorkflowState().GetExecutionInfo() + executionState := taskAttributes.GetWorkflowState().GetExecutionState() + currentVersionHistory, err := versionhistory.GetCurrentVersionHistory(executionInfo.GetVersionHistories()) + if err != nil { + return nil, err + } + lastItem, err := versionhistory.GetLastVersionHistoryItem(currentVersionHistory) + if err != nil { + return nil, err + } + + return &persistence.PutReplicationTaskToDLQRequest{ + ShardID: p.shard.GetShardID(), + SourceClusterName: p.sourceCluster, + TaskInfo: &persistencespb.ReplicationTaskInfo{ + NamespaceId: executionInfo.GetNamespaceId(), + WorkflowId: executionInfo.GetWorkflowId(), + RunId: executionState.GetRunId(), + TaskId: replicationTask.GetSourceTaskId(), + TaskType: enumsspb.TASK_TYPE_REPLICATION_SYNC_WORKFLOW_STATE, + Version: lastItem.GetVersion(), + }, + }, nil + default: return nil, fmt.Errorf("unknown replication task type: %v", replicationTask.TaskType) } diff --git a/service/history/replication/task_processor_test.go b/service/history/replication/task_processor_test.go index b271807d225..5620693b1eb 100644 --- a/service/history/replication/task_processor_test.go +++ b/service/history/replication/task_processor_test.go @@ -50,6 +50,7 @@ import ( "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" + "go.temporal.io/server/common/persistence/versionhistory" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/quotas" "go.temporal.io/server/common/resource" @@ -266,6 +267,28 @@ func (s *taskProcessorSuite) TestHandleReplicationDLQTask_SyncActivity() { s.NoError(err) } +func (s *taskProcessorSuite) TestHandleReplicationDLQTask_SyncWorkflowState() { + namespaceID := uuid.NewRandom().String() + workflowID := uuid.New() + runID := uuid.NewRandom().String() + + request := &persistence.PutReplicationTaskToDLQRequest{ + ShardID: s.shardID, + SourceClusterName: cluster.TestAlternativeClusterName, + TaskInfo: &persistencespb.ReplicationTaskInfo{ + NamespaceId: namespaceID, + WorkflowId: workflowID, + RunId: runID, + TaskType: enumsspb.TASK_TYPE_REPLICATION_SYNC_WORKFLOW_STATE, + Version: 1, + }, + } + + s.mockExecutionManager.EXPECT().PutReplicationTaskToDLQ(gomock.Any(), request).Return(nil) + err := s.replicationTaskProcessor.handleReplicationDLQTask(request) + s.NoError(err) +} + func (s *taskProcessorSuite) TestHandleReplicationDLQTask_History() { namespaceID := uuid.NewRandom().String() workflowID := uuid.New() @@ -319,6 +342,44 @@ func (s *taskProcessorSuite) TestConvertTaskToDLQTask_SyncActivity() { s.Equal(request, dlqTask) } +func (s *taskProcessorSuite) TestConvertTaskToDLQTask_SyncWorkflowState() { + namespaceID := uuid.NewRandom().String() + workflowID := uuid.New() + runID := uuid.NewRandom().String() + task := &replicationspb.ReplicationTask{ + TaskType: enumsspb.REPLICATION_TASK_TYPE_SYNC_WORKFLOW_STATE_TASK, + Attributes: &replicationspb.ReplicationTask_SyncWorkflowStateTaskAttributes{SyncWorkflowStateTaskAttributes: &replicationspb.SyncWorkflowStateTaskAttributes{ + WorkflowState: &persistencespb.WorkflowMutableState{ + ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ + NamespaceId: namespaceID, + WorkflowId: workflowID, + VersionHistories: versionhistory.NewVersionHistories( + versionhistory.NewVersionHistory(nil, []*historyspb.VersionHistoryItem{versionhistory.NewVersionHistoryItem(1, 1)}), + ), + }, + ExecutionState: &persistencespb.WorkflowExecutionState{ + RunId: runID, + }, + }, + }}, + } + request := &persistence.PutReplicationTaskToDLQRequest{ + ShardID: s.shardID, + SourceClusterName: cluster.TestAlternativeClusterName, + TaskInfo: &persistencespb.ReplicationTaskInfo{ + NamespaceId: namespaceID, + WorkflowId: workflowID, + RunId: runID, + TaskType: enumsspb.TASK_TYPE_REPLICATION_SYNC_WORKFLOW_STATE, + Version: 1, + }, + } + + dlqTask, err := s.replicationTaskProcessor.convertTaskToDLQTask(task) + s.NoError(err) + s.Equal(request, dlqTask) +} + func (s *taskProcessorSuite) TestConvertTaskToDLQTask_History() { namespaceID := uuid.NewRandom().String() workflowID := uuid.New()