diff --git a/libraries/bot-connector/src/main/java/com/microsoft/bot/connector/authentication/Retry.java b/libraries/bot-connector/src/main/java/com/microsoft/bot/connector/authentication/Retry.java index 12332dbb2..064dec5c3 100644 --- a/libraries/bot-connector/src/main/java/com/microsoft/bot/connector/authentication/Retry.java +++ b/libraries/bot-connector/src/main/java/com/microsoft/bot/connector/authentication/Retry.java @@ -4,11 +4,10 @@ package com.microsoft.bot.connector.authentication; -import com.microsoft.bot.connector.ExecutorFactory; - import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import java.util.function.Supplier; @@ -24,7 +23,7 @@ private Retry() { /** * Runs a task with retry. - * + * * @param task The task to run. * @param retryExceptionHandler Called when an exception happens. * @param <TResult> The type of the result. @@ -36,41 +35,50 @@ public static <TResult> CompletableFuture<TResult> run( Supplier<CompletableFuture<TResult>> task, BiFunction<RuntimeException, Integer, RetryParams> retryExceptionHandler ) { + return runInternal(task, retryExceptionHandler, 1, new ArrayList<>()); + } - CompletableFuture<TResult> result = new CompletableFuture<>(); + private static <TResult> CompletableFuture<TResult> runInternal( + Supplier<CompletableFuture<TResult>> task, + BiFunction<RuntimeException, Integer, RetryParams> retryExceptionHandler, + final Integer retryCount, + final List<Throwable> exceptions + ) { + AtomicReference<RetryParams> retry = new AtomicReference<>(); - ExecutorFactory.getExecutor().execute(() -> { - RetryParams retry = RetryParams.stopRetrying(); - List<Throwable> exceptions = new ArrayList<>(); - int currentRetryCount = 0; + return task.get() + .exceptionally((t) -> { + exceptions.add(t); + retry.set(retryExceptionHandler.apply(new RetryException(t), retryCount)); + return null; + }) + .thenCompose(taskResult -> { + CompletableFuture<TResult> result = new CompletableFuture<>(); - do { - try { - result.complete(task.get().join()); - } catch (Throwable t) { - exceptions.add(t); - retry = retryExceptionHandler.apply(new RetryException(t), currentRetryCount); + if (retry.get() == null) { + result.complete(taskResult); + return result; } - if (retry.getShouldRetry()) { - currentRetryCount++; + if (retry.get().getShouldRetry()) { try { - Thread.sleep(withBackoff(retry.getRetryAfter(), currentRetryCount)); + Thread.sleep(withBackOff(retry.get().getRetryAfter(), retryCount)); } catch (InterruptedException e) { throw new RetryException(e); } + + return runInternal(task, retryExceptionHandler, retryCount + 1, exceptions); } - } while (retry.getShouldRetry()); - result.completeExceptionally(new RetryException("Exceeded retry count", exceptions)); - }); + result.completeExceptionally(new RetryException("Exceeded retry count", exceptions)); - return result; + return result; + }); } private static final double BACKOFF_MULTIPLIER = 1.1; - private static long withBackoff(long delay, int retryCount) { + private static long withBackOff(long delay, int retryCount) { double result = delay * Math.pow(BACKOFF_MULTIPLIER, retryCount - 1); return (long) Math.min(result, Long.MAX_VALUE); } diff --git a/libraries/bot-connector/src/main/java/com/microsoft/bot/connector/authentication/RetryParams.java b/libraries/bot-connector/src/main/java/com/microsoft/bot/connector/authentication/RetryParams.java index 975132bb2..899c5225a 100644 --- a/libraries/bot-connector/src/main/java/com/microsoft/bot/connector/authentication/RetryParams.java +++ b/libraries/bot-connector/src/main/java/com/microsoft/bot/connector/authentication/RetryParams.java @@ -10,7 +10,7 @@ * State for Retry. */ public class RetryParams { - private static final int MAX_RETRIES = 10; + public static final int MAX_RETRIES = 10; private static final Duration MAX_DELAY = Duration.ofSeconds(10); private static final Duration DEFAULT_BACKOFF_TIME = Duration.ofMillis(50); @@ -23,11 +23,9 @@ public class RetryParams { * @return A RetryParams that returns false for {@link #getShouldRetry()}. */ public static RetryParams stopRetrying() { - return new RetryParams() { - { - setShouldRetry(false); - } - }; + return new RetryParams() {{ + setShouldRetry(false); + }}; } /** diff --git a/libraries/bot-connector/src/test/java/com/microsoft/bot/connector/RetryTests.java b/libraries/bot-connector/src/test/java/com/microsoft/bot/connector/RetryTests.java index a4ae0b77e..2eb8d711b 100644 --- a/libraries/bot-connector/src/test/java/com/microsoft/bot/connector/RetryTests.java +++ b/libraries/bot-connector/src/test/java/com/microsoft/bot/connector/RetryTests.java @@ -16,7 +16,7 @@ public void Retry_NoRetryWhenTaskSucceeds() { exceptionToThrow = null; }}; - String result = Retry.run(() -> + Retry.run(() -> faultyClass.faultyTask(), ((e, integer) -> faultyClass.exceptionHandler(e, integer))) .join(); @@ -32,8 +32,8 @@ public void Retry_RetryThenSucceed() { triesUntilSuccess = 3; }}; - String result = Retry.run(() -> - faultyClass.faultyTask(), + Retry.run(() -> + faultyClass.faultyTask(), ((e, integer) -> faultyClass.exceptionHandler(e, integer))) .join(); @@ -50,11 +50,14 @@ public void Retry_RetryUntilFailure() { try { Retry.run(() -> - faultyClass.faultyTask(), + faultyClass.faultyTask(), ((e, integer) -> faultyClass.exceptionHandler(e, integer))) .join(); + Assert.fail("Should have thrown a RetryException because it exceeded max retry"); } catch (CompletionException e) { Assert.assertTrue(e.getCause() instanceof RetryException); + Assert.assertEquals(RetryParams.MAX_RETRIES, faultyClass.callCount); + Assert.assertTrue(RetryParams.MAX_RETRIES == ((RetryException) e.getCause()).getExceptions().size()); } } @@ -69,7 +72,9 @@ CompletableFuture<String> faultyTask() { callCount++; if (callCount < triesUntilSuccess && exceptionToThrow != null) { - throw exceptionToThrow; + CompletableFuture<String> result = new CompletableFuture<>(); + result.completeExceptionally(exceptionToThrow); + return result; } return CompletableFuture.completedFuture(null);