Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move getTaskQueueManager call into redirect methods #4759

Merged
merged 2 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 10 additions & 30 deletions service/matching/matching_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,7 @@ func (e *matchingEngineImpl) AddWorkflowTask(
// We don't need the userDataChanged channel here because:
// - if we sync match or sticky worker unavailable, we're done
// - if we spool to db, we'll re-resolve when it comes out of the db
taskQueue, _, err := baseTqm.RedirectToVersionedQueueForAdd(ctx, addRequest.VersionDirective)
if err != nil {
return false, err
}

tqm, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, true)
tqm, _, err := baseTqm.RedirectToVersionedQueueForAdd(ctx, addRequest.VersionDirective)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -393,12 +388,7 @@ func (e *matchingEngineImpl) AddActivityTask(
// We don't need the userDataChanged channel here because:
// - if we sync match, we're done
// - if we spool to db, we'll re-resolve when it comes out of the db
taskQueue, _, err := baseTqm.RedirectToVersionedQueueForAdd(ctx, addRequest.VersionDirective)
if err != nil {
return false, err
}

tqm, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, true)
tqm, _, err := baseTqm.RedirectToVersionedQueueForAdd(ctx, addRequest.VersionDirective)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -443,16 +433,15 @@ func (e *matchingEngineImpl) DispatchSpooledTask(
unversionedOrigTaskQueue := newTaskQueueIDWithVersionSet(origTaskQueue, "")
// Redirect and re-resolve if we're blocked in matcher and user data changes.
for {
// If normal queue: always load the base tqm to get versioning data.
// If sticky queue: sticky is not versioned, so if we got here (by taskReader calling this),
// the queue is already loaded.
// So we can always use true here.
baseTqm, err := e.getTaskQueueManager(ctx, unversionedOrigTaskQueue, stickyInfo, true)
if err != nil {
return err
}
taskQueue, userDataChanged, err := baseTqm.RedirectToVersionedQueueForAdd(ctx, directive)
if err != nil {
return err
}
sticky := stickyInfo.kind == enumspb.TASK_QUEUE_KIND_STICKY
tqm, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, !sticky)
tqm, userDataChanged, err := baseTqm.RedirectToVersionedQueueForAdd(ctx, directive)
if err != nil {
return err
}
Expand Down Expand Up @@ -692,18 +681,13 @@ func (e *matchingEngineImpl) QueryWorkflow(

// We don't need the userDataChanged channel here because we either do this sync (local or remote)
// or fail with a relatively short timeout.
taskQueue, _, err := baseTqm.RedirectToVersionedQueueForAdd(ctx, queryRequest.VersionDirective)
tqm, _, err := baseTqm.RedirectToVersionedQueueForAdd(ctx, queryRequest.VersionDirective)
if err != nil {
return nil, err
} else if taskQueue.VersionSet() == dlqVersionSet {
} else if tqm.QueueID().VersionSet() == dlqVersionSet {
return nil, serviceerror.NewFailedPrecondition("Operations on versioned workflows are disabled")
}

tqm, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, !sticky)
if err != nil {
return nil, err
}

taskID := uuid.New()
resp, err := tqm.DispatchQueryTask(ctx, taskID, queryRequest)

Expand Down Expand Up @@ -1204,18 +1188,14 @@ func (e *matchingEngineImpl) getTask(
return nil, err
}

taskQueue, err := baseTqm.RedirectToVersionedQueueForPoll(pollMetadata.workerVersionCapabilities)
tqm, err := baseTqm.RedirectToVersionedQueueForPoll(ctx, pollMetadata.workerVersionCapabilities)
if err != nil {
if errors.Is(err, errUserDataDisabled) {
// Rewrite to nicer error message
err = serviceerror.NewFailedPrecondition("Operations on versioned workflows are disabled")
}
return nil, err
}
tqm, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, true)
if err != nil {
return nil, err
}

// We need to set a shorter timeout than the original ctx; otherwise, by the time ctx deadline is
// reached, instead of emptyTask, context timeout error is returned to the frontend by the rpc stack,
Expand Down
41 changes: 28 additions & 13 deletions service/matching/task_queue_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ type (
QueueID() *taskQueueID
TaskQueueKind() enumspb.TaskQueueKind
LongPollExpirationInterval() time.Duration
RedirectToVersionedQueueForAdd(context.Context, *taskqueuespb.TaskVersionDirective) (*taskQueueID, chan struct{}, error)
RedirectToVersionedQueueForPoll(*commonpb.WorkerVersionCapabilities) (*taskQueueID, error)
RedirectToVersionedQueueForAdd(context.Context, *taskqueuespb.TaskVersionDirective) (taskQueueManager, chan struct{}, error)
RedirectToVersionedQueueForPoll(context.Context, *commonpb.WorkerVersionCapabilities) (taskQueueManager, error)
}

// Single task queue in memory state
Expand Down Expand Up @@ -765,11 +765,11 @@ func (c *taskQueueManagerImpl) LongPollExpirationInterval() time.Duration {
return c.config.LongPollExpirationInterval()
}

func (c *taskQueueManagerImpl) RedirectToVersionedQueueForPoll(caps *commonpb.WorkerVersionCapabilities) (*taskQueueID, error) {
func (c *taskQueueManagerImpl) RedirectToVersionedQueueForPoll(ctx context.Context, caps *commonpb.WorkerVersionCapabilities) (taskQueueManager, error) {
if !caps.GetUseVersioning() {
// Either this task queue is versioned, or there are still some workflows running on
// the "unversioned" set.
return c.taskQueueID, nil
return c, nil
}
// We don't need the userDataChanged channel here because polls have a timeout and the
// client will retry, so if we're blocked on the wrong matcher it'll just take one poll
Expand All @@ -789,7 +789,7 @@ func (c *taskQueueManagerImpl) RedirectToVersionedQueueForPoll(caps *commonpb.Wo
if unknownBuild {
c.recordUnknownBuildPoll(caps.BuildId)
}
return c.taskQueueID, nil
return c, nil
}

versionSet, unknownBuild, err := lookupVersionSetForPoll(data, caps)
Expand All @@ -799,10 +799,12 @@ func (c *taskQueueManagerImpl) RedirectToVersionedQueueForPoll(caps *commonpb.Wo
if unknownBuild {
c.recordUnknownBuildPoll(caps.BuildId)
}
return newTaskQueueIDWithVersionSet(c.taskQueueID, versionSet), nil

newId := newTaskQueueIDWithVersionSet(c.taskQueueID, versionSet)
return c.engine.getTaskQueueManager(ctx, newId, c.stickyInfo, true)
}

func (c *taskQueueManagerImpl) RedirectToVersionedQueueForAdd(ctx context.Context, directive *taskqueuespb.TaskVersionDirective) (*taskQueueID, chan struct{}, error) {
func (c *taskQueueManagerImpl) RedirectToVersionedQueueForAdd(ctx context.Context, directive *taskqueuespb.TaskVersionDirective) (taskQueueManager, chan struct{}, error) {
var buildId string
switch dir := directive.GetValue().(type) {
case *taskqueuespb.TaskVersionDirective_UseDefault:
Expand All @@ -811,7 +813,7 @@ func (c *taskQueueManagerImpl) RedirectToVersionedQueueForAdd(ctx context.Contex
buildId = dir.BuildId
default:
// Unversioned task, leave on unversioned queue.
return c.taskQueueID, nil, nil
return c, nil, nil
}

// Have to look up versioning data.
Expand All @@ -820,14 +822,21 @@ func (c *taskQueueManagerImpl) RedirectToVersionedQueueForAdd(ctx context.Contex
if errors.Is(err, errUserDataDisabled) {
// When user data disabled, send "default" tasks to unversioned queue.
if buildId == "" {
return c.taskQueueID, userDataChanged, nil
return c, userDataChanged, nil
}
// Send versioned sticky back to regular queue so they can go in the dlq.
if c.kind == enumspb.TASK_QUEUE_KIND_STICKY {
return nil, nil, serviceerrors.NewStickyWorkerUnavailable()
}
// Send versioned tasks to dlq.
return newTaskQueueIDWithVersionSet(c.taskQueueID, dlqVersionSet), userDataChanged, nil
newId := newTaskQueueIDWithVersionSet(c.taskQueueID, dlqVersionSet)
// If we're called by QueryWorkflow, then we technically don't need to load the dlq
// tqm here. But it's not a big deal if we do.
tqm, err := c.engine.getTaskQueueManager(ctx, newId, c.stickyInfo, true)
if err != nil {
return nil, nil, err
}
return tqm, userDataChanged, nil
}
return nil, nil, err
}
Expand All @@ -844,13 +853,13 @@ func (c *taskQueueManagerImpl) RedirectToVersionedQueueForAdd(ctx context.Contex
// Don't bother persisting the unknown build id in this case: sticky tasks have a
// short timeout, so it doesn't matter if they get lost.
}
return c.taskQueueID, userDataChanged, nil
return c, userDataChanged, nil
}

versionSet, unknownBuild, err := lookupVersionSetForAdd(data, buildId)
if err == errEmptyVersioningData { // nolint:goerr113
// default was requested for an unversioned queue
return c.taskQueueID, userDataChanged, nil
return c, userDataChanged, nil
} else if err != nil {
return nil, nil, err
}
Expand All @@ -868,7 +877,13 @@ func (c *taskQueueManagerImpl) RedirectToVersionedQueueForAdd(ctx context.Contex
return nil, nil, err
}
}
return newTaskQueueIDWithVersionSet(c.taskQueueID, versionSet), userDataChanged, nil

newId := newTaskQueueIDWithVersionSet(c.taskQueueID, versionSet)
tqm, err := c.engine.getTaskQueueManager(ctx, newId, c.stickyInfo, true)
if err != nil {
return nil, nil, err
}
return tqm, userDataChanged, nil
}

func (c *taskQueueManagerImpl) recordUnknownBuildPoll(buildId string) {
Expand Down