Skip to content

Commit

Permalink
Merge pull request #66767 from ryanbrandenburg/dev/toddgrun/RequestCo…
Browse files Browse the repository at this point in the history
…ncurrency

Dev/toddgrun/request concurrency
  • Loading branch information
ToddGrun authored Feb 13, 2023
2 parents c768df7 + 2bde77f commit 5cf6350
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Elfie.Diagnostics;
Expand All @@ -13,6 +15,7 @@
using StreamJsonRpc;
using Xunit;
using static Microsoft.CommonLanguageServerProtocol.Framework.UnitTests.HandlerProviderTests;
using static Microsoft.CommonLanguageServerProtocol.Framework.UnitTests.RequestExecutionQueueTests;

namespace Microsoft.CommonLanguageServerProtocol.Framework.UnitTests;

Expand All @@ -31,14 +34,29 @@ protected override ILspServices ConstructLspServices()
}

private const string MethodName = "SomeMethod";
private const string CancellingMethod = "CancellingMethod";
private const string CompletingMethod = "CompletingMethod";
private const string MutatingMethod = "MutatingMethod";

private static RequestExecutionQueue<TestRequestContext> GetRequestExecutionQueue(IMethodHandler? methodHandler = null)
private static RequestExecutionQueue<TestRequestContext> GetRequestExecutionQueue(bool cancelInProgressWorkUponMutatingRequest, params IMethodHandler[] methodHandlers)
{
var handlerProvider = new Mock<IHandlerProvider>(MockBehavior.Strict);
var handler = methodHandler ?? GetTestMethodHandler();
handlerProvider.Setup(h => h.GetMethodHandler(MethodName, TestMethodHandler.RequestType, TestMethodHandler.ResponseType)).Returns(handler);
if (methodHandlers.Length == 0)
{
var handler = GetTestMethodHandler();
handlerProvider.Setup(h => h.GetMethodHandler(MethodName, TestMethodHandler.RequestType, TestMethodHandler.ResponseType)).Returns(handler);
}

foreach (var methodHandler in methodHandlers)
{
var methodType = methodHandler.GetType();
var methodAttribute = methodType.GetCustomAttribute<LanguageServerEndpointAttribute>();
var method = methodAttribute.Method;

var executionQueue = new RequestExecutionQueue<TestRequestContext>(new MockServer(), NoOpLspLogger.Instance, handlerProvider.Object);
handlerProvider.Setup(h => h.GetMethodHandler(method, typeof(int), typeof(string))).Returns(methodHandler);
}

var executionQueue = new TestRequestExecutionQueue(new MockServer(), NoOpLspLogger.Instance, handlerProvider.Object, cancelInProgressWorkUponMutatingRequest);
executionQueue.Start();

return executionQueue;
Expand All @@ -65,19 +83,45 @@ private static TestMethodHandler GetTestMethodHandler()
[Fact]
public async Task ExecuteAsync_ThrowCompletes()
{
// Arrange
var throwingHandler = new ThrowingHandler();
var requestExecutionQueue = GetRequestExecutionQueue(throwingHandler);
var request = 1;
var requestExecutionQueue = GetRequestExecutionQueue(false, throwingHandler);
var lspServices = GetLspServices();

await Assert.ThrowsAsync<NotImplementedException>(() => requestExecutionQueue.ExecuteAsync<int, string>(request, MethodName, lspServices, CancellationToken.None));
// Act & Assert
await Assert.ThrowsAsync<NotImplementedException>(() => requestExecutionQueue.ExecuteAsync<int, string>(1, MethodName, lspServices, CancellationToken.None));
}

[Fact]
public async Task ExecuteAsync_WithCancelInProgressWork_CancelsInProgressWorkWhenMutatingRequestArrives()
{
// Let's try it a bunch of times to try to find timing issues.
for (var i = 0; i < 20; i++)
{
// Arrange
var mutatingHandler = new MutatingHandler();
var cancellingHandler = new CancellingHandler();
var completingHandler = new CompletingHandler();
var requestExecutionQueue = GetRequestExecutionQueue(cancelInProgressWorkUponMutatingRequest: true, methodHandlers: new IMethodHandler[] { cancellingHandler, completingHandler, mutatingHandler });
var lspServices = GetLspServices();

var cancellingRequestCancellationToken = new CancellationToken();
var completingRequestCancellationToken = new CancellationToken();

var _ = requestExecutionQueue.ExecuteAsync<int, string>(1, CancellingMethod, lspServices, cancellingRequestCancellationToken);
var _1 = requestExecutionQueue.ExecuteAsync<int, string>(1, CompletingMethod, lspServices, completingRequestCancellationToken);

// Act & Assert
// A Debug.Assert would throw if the tasks hadn't completed when the mutating request is called.
await requestExecutionQueue.ExecuteAsync<int, string>(1, MutatingMethod, lspServices, CancellationToken.None);
}
}

[Fact]
public async Task Dispose_MultipleTimes_Succeeds()
{
// Arrange
var requestExecutionQueue = GetRequestExecutionQueue();
var requestExecutionQueue = GetRequestExecutionQueue(false);

// Act
await requestExecutionQueue.DisposeAsync();
Expand All @@ -86,20 +130,10 @@ public async Task Dispose_MultipleTimes_Succeeds()
// Assert, it didn't fail
}

public class ThrowingHandler : IRequestHandler<int, string, TestRequestContext>
{
public bool MutatesSolutionState => false;

public Task<string> HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
}

[Fact]
public async Task ExecuteAsync_CompletesTask()
{
var requestExecutionQueue = GetRequestExecutionQueue();
var requestExecutionQueue = GetRequestExecutionQueue(false);
var request = 1;
var lspServices = GetLspServices();

Expand All @@ -111,7 +145,7 @@ public async Task ExecuteAsync_CompletesTask()
[Fact]
public async Task Queue_DrainsOnShutdown()
{
var requestExecutionQueue = GetRequestExecutionQueue();
var requestExecutionQueue = GetRequestExecutionQueue(false);
var request = 1;
var lspServices = GetLspServices();

Expand All @@ -124,7 +158,75 @@ public async Task Queue_DrainsOnShutdown()
Assert.True(task2.IsCompleted);
}

private class TestResponse
private class TestRequestExecutionQueue : RequestExecutionQueue<TestRequestContext>
{
private readonly bool _cancelInProgressWorkUponMutatingRequest;

public TestRequestExecutionQueue(AbstractLanguageServer<TestRequestContext> languageServer, ILspLogger logger, IHandlerProvider handlerProvider, bool cancelInProgressWorkUponMutatingRequest)
: base(languageServer, logger, handlerProvider)
{
_cancelInProgressWorkUponMutatingRequest = cancelInProgressWorkUponMutatingRequest;
}

protected override bool CancelInProgressWorkUponMutatingRequest => _cancelInProgressWorkUponMutatingRequest;
}

[LanguageServerEndpoint(MutatingMethod)]
public class MutatingHandler : IRequestHandler<int, string, TestRequestContext>
{
public MutatingHandler()
{
}

public bool MutatesSolutionState => true;

public Task<string> HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken)
{
return Task.FromResult(string.Empty);
}
}

[LanguageServerEndpoint(CompletingMethod)]
public class CompletingHandler : IRequestHandler<int, string, TestRequestContext>
{
public bool MutatesSolutionState => false;

public async Task<string> HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken)
{
while (true)
{
if (cancellationToken.IsCancellationRequested)
{
return "I completed!";
}
await Task.Delay(100);
}
}
}

[LanguageServerEndpoint(CancellingMethod)]
public class CancellingHandler : IRequestHandler<int, string, TestRequestContext>
{
public bool MutatesSolutionState => false;

public async Task<string> HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken)
{
while (true)
{
cancellationToken.ThrowIfCancellationRequested();
await Task.Delay(100);
}
}
}

[LanguageServerEndpoint(MethodName)]
public class ThrowingHandler : IRequestHandler<int, string, TestRequestContext>
{
public bool MutatesSolutionState => false;

public Task<string> HandleRequestAsync(int request, TestRequestContext context, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.Threading;
using System.Collections.Immutable;

namespace Microsoft.CommonLanguageServerProtocol.Framework;

/// <summary>
/// Coordinates the exectution of LSP messages to ensure correct results are sent back.
/// Coordinates the execution of LSP messages to ensure correct results are sent back.
/// </summary>
/// <remarks>
/// <para>
Expand All @@ -21,7 +22,7 @@ namespace Microsoft.CommonLanguageServerProtocol.Framework;
/// (via textDocument/didChange for example).
/// </para>
/// <para>
/// This class acheives this by distinguishing between mutating and non-mutating requests, and ensuring that
/// This class achieves this by distinguishing between mutating and non-mutating requests, and ensuring that
/// when a mutating request comes in, its processing blocks all subsequent requests. As each request comes in
/// it is added to a queue, and a queue item will not be retrieved while a mutating request is running. Before
/// any request is handled the solution state is created by merging workspace solution state, which could have
Expand Down Expand Up @@ -89,6 +90,19 @@ protected IMethodHandler GetMethodHandler<TRequest, TResponse>(string methodName
return handler;
}

/// <summary>
/// Indicates this queue requires in-progress work to be cancelled before servicing
/// a mutating request.
/// </summary>
/// <remarks>
/// This was added for WebTools consumption as they aren't resilient to
/// incomplete requests continuing execution during didChange notifications. As their
/// parse trees are mutable, a didChange notification requires all previous requests
/// to be completed before processing. This is similar to the O#
/// WithContentModifiedSupport(false) behavior.
/// </remarks>
protected virtual bool CancelInProgressWorkUponMutatingRequest => false;

/// <summary>
/// Queues a request to be handled by the specified handler, with mutating requests blocking subsequent requests
/// from starting until the mutation is complete.
Expand Down Expand Up @@ -156,6 +170,8 @@ private async Task ProcessQueueAsync()
ILspServices? lspServices = null;
try
{
var concurrentlyExecutingTasks = new ConcurrentDictionary<Task, CancellationTokenSource>();

while (!_cancelSource.IsCancellationRequested)
{
// First attempt to de-queue the work item in its own try-catch.
Expand All @@ -175,9 +191,27 @@ private async Task ProcessQueueAsync()
try
{
var (work, activityId, cancellationToken) = queueItem;
CancellationTokenSource? currentWorkCts = null;
lspServices = work.LspServices;

var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken, cancellationToken);
if (CancelInProgressWorkUponMutatingRequest)
{
try
{
currentWorkCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken, cancellationToken);
}
catch (ObjectDisposedException)
{
// Explicitly ignore this exception as this can occur during the CreateLinkTokenSource call, and means one of the
// linked cancellationTokens has been cancelled. If this occurs, skip to the next loop iteration as this
// queueItem requires no processing
continue;
}

// Use the linked cancellation token so it's task can be cancelled if necessary during a mutating request
// on a queue that specifies CancelInProgressWorkUponMutatingRequest
cancellationToken = currentWorkCts.Token;
}

// Restore our activity id so that logging/tracking works across asynchronous calls.
Trace.CorrelationManager.ActivityId = activityId;
Expand All @@ -186,23 +220,60 @@ private async Task ProcessQueueAsync()
var context = await work.CreateRequestContextAsync(cancellationToken).ConfigureAwait(false);
if (work.MutatesServerState)
{
if (CancelInProgressWorkUponMutatingRequest)
{
// Cancel all concurrently executing tasks
var concurrentlyExecutingTasksArray = concurrentlyExecutingTasks.ToArray();
for (var i = 0; i < concurrentlyExecutingTasksArray.Length; i++)
{
concurrentlyExecutingTasksArray[i].Value.Cancel();
}

// wait for all pending tasks to complete their cancellation, ignoring any exceptions
await Task.WhenAll(concurrentlyExecutingTasksArray.Select(kvp => kvp.Key)).NoThrowAwaitableInternal(captureContext: false);
}

Debug.Assert(!concurrentlyExecutingTasks.Any(), "The tasks should have all been drained before continuing");
// Mutating requests block other requests from starting to ensure an up to date snapshot is used.
// Since we're explicitly awaiting exceptions to mutating requests will bubble up here.
await WrapStartRequestTaskAsync(work.StartRequestAsync(context, cancellationToken), rethrowExceptions: true).ConfigureAwait(false);
}
else
{
// Non mutating are fire-and-forget because they are by definition readonly. Any errors
// Non mutating are fire-and-forget because they are by definition read-only. Any errors
// will be sent back to the client but they can also be captured via HandleNonMutatingRequestError,
// though these errors don't put us into a bad state as far as the rest of the queue goes.
// Furthermore we use Task.Run here to protect ourselves against synchronous execution of work
// blocking the request queue for longer periods of time (it enforces parallelizabilty).
_ = WrapStartRequestTaskAsync(Task.Run(() => work.StartRequestAsync(context, cancellationToken), cancellationToken), rethrowExceptions: false);
// blocking the request queue for longer periods of time (it enforces parallelizability).
var currentWorkTask = WrapStartRequestTaskAsync(Task.Run(() => work.StartRequestAsync(context, cancellationToken), cancellationToken), rethrowExceptions: false);

if (CancelInProgressWorkUponMutatingRequest)
{
if (currentWorkCts is null)
{
throw new InvalidOperationException($"unexpected null value for {nameof(currentWorkCts)}");
}

if (!concurrentlyExecutingTasks.TryAdd(currentWorkTask, currentWorkCts))
{
throw new InvalidOperationException($"unable to add {nameof(currentWorkTask)} into {nameof(concurrentlyExecutingTasks)}");
}

_ = currentWorkTask.ContinueWith(t =>
{
if (!concurrentlyExecutingTasks.TryRemove(t, out var concurrentlyExecutingTaskCts))
{
throw new InvalidOperationException($"unexpected failure to remove task from {nameof(concurrentlyExecutingTasks)}");
}

concurrentlyExecutingTaskCts.Dispose();
}, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
}
}
}
catch (OperationCanceledException ex) when (ex.CancellationToken == queueItem.cancellationToken)
catch (OperationCanceledException)
{
// Explicitly ignore this exception as cancellation occured as a result of our linked cancellation token.
// Explicitly ignore this exception as cancellation occurred as a result of our linked cancellation token.
// This means either the queue is shutting down or the request itself was cancelled.
// 1. If the queue is shutting down, then while loop will exit before the next iteration since it checks for cancellation.
// 2. Request cancellations are normal so no need to report anything there.
Expand All @@ -227,7 +298,7 @@ private async Task ProcessQueueAsync()
}

/// <summary>
/// Provides an extensiblity point to log or otherwise inspect errors thrown from non-mutating requests,
/// Provides an extensibility point to log or otherwise inspect errors thrown from non-mutating requests,
/// which would otherwise be lost to the fire-and-forget task in the queue.
/// </summary>
/// <param name="nonMutatingRequestTask">The task to be inspected.</param>
Expand Down
Loading

0 comments on commit 5cf6350

Please # to comment.