Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Optimize Interlocked.CompareExchange use in Task #93953

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,6 @@ public class Task<TResult> : Task
// The value itself, if set.
internal TResult? m_result;

// Extract rarely used helper for a static method in a separate type so that the Func<Task<Task>, Task<TResult>>
// generic instantiations don't contribute to all Task instantiations, but only those where WhenAny is used.
internal static class TaskWhenAnyCast
{
// Delegate used by:
// public static Task<Task<TResult>> WhenAny<TResult>(IEnumerable<Task<TResult>> tasks);
// public static Task<Task<TResult>> WhenAny<TResult>(params Task<TResult>[] tasks);
// Used to "cast" from Task<Task> to Task<Task<TResult>>.
internal static readonly Func<Task<Task>, Task<TResult>> Value = completed => (Task<TResult>)completed.Result;
}

// Construct a promise-style task without any options.
internal Task()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3473,23 +3473,24 @@ private void RunContinuations(object continuationObject) // separated out of Fin
}

// Not a single; it must be a list.
List<object?> continuations = (List<object?>)continuationObject;
List<object?> list = (List<object?>)continuationObject;

//
// Begin processing of continuation list
//

// Wait for any concurrent adds or removes to be retired
lock (continuations) { }
int continuationCount = continuations.Count;
Monitor.Enter(list);
Monitor.Exit(list);
Span<object?> continuations = CollectionsMarshal.AsSpan(list);

// Fire the asynchronous continuations first. However, if we're not able to run any continuations synchronously,
// then we can skip this first pass, since the second pass that tries to run everything synchronously will instead
// run everything asynchronously anyway.
if (canInlineContinuations)
{
bool forceContinuationsAsync = false;
for (int i = 0; i < continuationCount; i++)
for (int i = 0; i < continuations.Length; i++)
{
// For StandardTaskContinuations, we respect the TaskContinuationOptions.ExecuteSynchronously option,
// as the developer needs to explicitly opt-into running the continuation synchronously, and if they do,
Expand Down Expand Up @@ -3543,7 +3544,7 @@ private void RunContinuations(object continuationObject) // separated out of Fin
}

// ... and then fire the synchronous continuations (if there are any).
for (int i = 0; i < continuationCount; i++)
for (int i = 0; i < continuations.Length; i++)
{
object? currentContinuation = continuations[i];
if (currentContinuation == null)
Expand Down Expand Up @@ -4510,62 +4511,79 @@ internal void AddCompletionAction(ITaskCompletionAction action, bool addBeforeOt

// Support method for AddTaskContinuation that takes care of multi-continuation logic.
// Returns true if and only if the continuation was successfully queued.
// THIS METHOD ASSUMES THAT m_continuationObject IS NOT NULL. That case was taken
// care of in the calling method, AddTaskContinuation().
private bool AddTaskContinuationComplex(object tc, bool addBeforeOthers)
{
Debug.Assert(tc != null, "Expected non-null tc object in AddTaskContinuationComplex");

object? oldValue = m_continuationObject;
Debug.Assert(oldValue is not null, "Expected non-null m_continuationObject object");
if (oldValue == s_taskCompletionSentinel)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
return false;
}

// Logic for the case where we were previously storing a single continuation
if ((oldValue != s_taskCompletionSentinel) && (!(oldValue is List<object?>)))
List<object?>? list = oldValue as List<object?>;
if (list is null)
{
// Construct a new TaskContinuation list and CAS it in.
Interlocked.CompareExchange(ref m_continuationObject, new List<object?> { oldValue }, oldValue);
list = new List<object?>();
if (addBeforeOthers)
{
list.Add(tc);
list.Add(oldValue);
}
else
{
list.Add(oldValue);
list.Add(tc);
}

object? expected = oldValue;
oldValue = Interlocked.CompareExchange(ref m_continuationObject, list, expected);
if (oldValue == expected)
{
// We successfully stored the new list with both continuations in it, so we're done.
return true;
}

// We might be racing against another thread converting the single into
// a list, or we might be racing against task completion, so resample "list"
// below.
// a list, or we might be racing against task completion, so recheck for list again.
list = oldValue as List<object?>;
if (list is null)
{
Debug.Assert(oldValue == s_taskCompletionSentinel, "Expected m_continuationObject to be list or sentinel");
return false;
}
}

// m_continuationObject is guaranteed at this point to be either a List or
// s_taskCompletionSentinel.
List<object?>? list = m_continuationObject as List<object?>;
Debug.Assert((list != null) || (m_continuationObject == s_taskCompletionSentinel),
"Expected m_continuationObject to be list or sentinel");

// If list is null, it can only mean that s_taskCompletionSentinel has been exchanged
// into m_continuationObject. Thus, the task has completed and we should return false
// from this method, as we will not be queuing up the continuation.
if (list != null)
lock (list)
{
lock (list)
// It is possible for the task to complete right after we snap the copy of
// the list. If so, then return false without queuing the continuation.
if (m_continuationObject == s_taskCompletionSentinel)
{
// It is possible for the task to complete right after we snap the copy of
// the list. If so, then fall through and return false without queuing the
// continuation.
if (m_continuationObject != s_taskCompletionSentinel)
{
// Before growing the list we remove possible null entries that are the
// result from RemoveContinuations()
if (list.Count == list.Capacity)
{
list.RemoveAll(l => l == null);
}
return false;
}

if (addBeforeOthers)
list.Insert(0, tc);
else
list.Add(tc);
// Before growing the list we remove possible null entries that are the
// result from RemoveContinuations()
if (list.Count == list.Capacity)
{
list.RemoveAll(l => l == null);
}

return true; // continuation successfully queued, so return true.
}
if (addBeforeOthers)
{
list.Insert(0, tc);
}
else
{
list.Add(tc);
}
}

// We didn't succeed in queuing the continuation, so return false.
return false;
return true; // continuation successfully queued, so return true.
}

// Record a continuation task or action.
Expand Down Expand Up @@ -4603,12 +4621,15 @@ internal void RemoveContinuation(object continuationObject) // could be TaskCont
{
// This is not a list. If we have a single object (the one we want to remove) we try to replace it with an empty list.
// Note we cannot go back to a null state, since it will mess up the AddTaskContinuation logic.
if (Interlocked.CompareExchange(ref m_continuationObject, new List<object?>(), continuationObject) != continuationObject)
continuationsLocalRef = Interlocked.CompareExchange(ref m_continuationObject, new List<object?>(), continuationObject);
if (continuationsLocalRef != continuationObject)
{
// If we fail it means that either AddContinuationComplex won the race condition and m_continuationObject is now a List
// that contains the element we want to remove. Or FinishContinuations set the s_taskCompletionSentinel.
// So we should try to get a list one more time
continuationsLocalListRef = m_continuationObject as List<object?>;
// So we should try to get a list one more time and if it's null then there is nothing else to do.
continuationsLocalListRef = continuationsLocalRef as List<object?>;
if (continuationsLocalListRef is null)
return;
}
else
{
Expand All @@ -4617,24 +4638,20 @@ internal void RemoveContinuation(object continuationObject) // could be TaskCont
}
}

// if continuationsLocalRef == null it means s_taskCompletionSentinel has been set already and there is nothing else to do.
if (continuationsLocalListRef != null)
lock (continuationsLocalListRef)
{
lock (continuationsLocalListRef)
{
// There is a small chance that this task completed since we took a local snapshot into
// continuationsLocalRef. In that case, just return; we don't want to be manipulating the
// continuation list as it is being processed.
if (m_continuationObject == s_taskCompletionSentinel) return;
// There is a small chance that this task completed since we took a local snapshot into
// continuationsLocalRef. In that case, just return; we don't want to be manipulating the
// continuation list as it is being processed.
if (m_continuationObject == s_taskCompletionSentinel) return;

// Find continuationObject in the continuation list
int index = continuationsLocalListRef.IndexOf(continuationObject);
// Find continuationObject in the continuation list
int index = continuationsLocalListRef.IndexOf(continuationObject);

if (index >= 0)
{
// null out that TaskContinuation entry, which will be interpreted as "to be cleaned up"
continuationsLocalListRef[index] = null;
}
if (index >= 0)
{
// null out that TaskContinuation entry, which will be interpreted as "to be cleaned up"
continuationsLocalListRef[index] = null;
}
}
}
Expand Down Expand Up @@ -5902,14 +5919,14 @@ internal static Task WhenAll(ReadOnlySpan<Task> tasks) => // TODO https://github
/// <summary>A Task that gets completed when all of its constituent tasks complete.</summary>
private sealed class WhenAllPromise : Task, ITaskCompletionAction
{
/// <summary>Either a single faulted/canceled task, or a list of faulted/canceled tasks.</summary>
private object? _failedOrCanceled;
/// <summary>The number of tasks remaining to complete.</summary>
private int _remainingToComplete;

internal WhenAllPromise(ReadOnlySpan<Task> tasks)
{
Debug.Assert(tasks.Length != 0, "Expected a non-zero length task array");
Debug.Assert(m_stateObject is null, "Expected to be able to use the state object field for faulted/canceled tasks.");
m_stateFlags |= (int)InternalTaskOptions.HiddenState;

// Throw if any of the provided tasks is null. This is best effort to inform the caller
// they've made a mistake. If between the time we check for nulls and the time we hook
Expand Down Expand Up @@ -5966,16 +5983,14 @@ public void Invoke(Task? completedTask)
if (!completedTask.IsCompletedSuccessfully)
{
// Try to store the completed task as the first that's failed or faulted.
if (Interlocked.CompareExchange(ref _failedOrCanceled, completedTask, null) != null)
object? failedOrCanceled = Interlocked.CompareExchange(ref m_stateObject, completedTask, null);
if (failedOrCanceled != null)
{
// There was already something there.
while (true)
{
object? failedOrCanceled = _failedOrCanceled;
Debug.Assert(failedOrCanceled is not null);

// If it was a list, add it to the list.
if (_failedOrCanceled is List<Task> list)
if (failedOrCanceled is List<Task> list)
{
lock (list)
{
Expand All @@ -5986,13 +6001,15 @@ public void Invoke(Task? completedTask)

// Otherwise, it was a Task. Create a new list containing that task and this one, and store it in.
Debug.Assert(failedOrCanceled is Task, $"Expected Task, got {failedOrCanceled}");
if (Interlocked.CompareExchange(ref _failedOrCanceled, new List<Task> { (Task)failedOrCanceled, completedTask }, failedOrCanceled) == failedOrCanceled)
Task first = (Task)failedOrCanceled;
failedOrCanceled = Interlocked.CompareExchange(ref m_stateObject, new List<Task> { first, completedTask }, first);
if (failedOrCanceled == first)
{
break;
}

// We lost the race, which means we should loop around one more time and it'll be a list.
Debug.Assert(_failedOrCanceled is List<Task>);
Debug.Assert(failedOrCanceled is List<Task>);
}
}
}
Expand All @@ -6001,7 +6018,7 @@ public void Invoke(Task? completedTask)
// Decrement the count, and only continue to complete the promise if we're the last one.
if (Interlocked.Decrement(ref _remainingToComplete) == 0)
{
object? failedOrCanceled = _failedOrCanceled;
object? failedOrCanceled = m_stateObject;
if (failedOrCanceled is null)
{
if (TplEventSource.Log.IsEnabled())
Expand Down
Loading