From 42e81007ad1238bb7eb4d70dad735a5089b351ac Mon Sep 17 00:00:00 2001 From: John Viegas Date: Tue, 4 Mar 2025 16:13:48 -0800 Subject: [PATCH] Refactor RequestBatchBuffer to seperate out flush policy and batch storage --- ...RequestBatchManagerSqsIntegrationTest.java | 1 + .../services/sqs/SqsSendMessageApp.java | 69 ++++++ .../services/sqs/SqsSendMessageOld.java | 140 ++++++++++++ .../batchmanager/BatchEntryIdGenerator.java | 58 +++++ .../internal/batchmanager/BatchingMap.java | 19 +- .../internal/batchmanager/FlushPolicy.java | 76 +++++++ .../internal/batchmanager/FlushScheduler.java | 53 +++++ .../batchmanager/RequestBatchBuffer.java | 207 +++++++++--------- .../batchmanager/RequestBatchStorage.java | 104 +++++++++ .../batchmanager/RequestBatchBufferTest.java | 133 ++++++++--- 10 files changed, 712 insertions(+), 148 deletions(-) create mode 100644 services/sqs/src/it/java/software/amazon/awssdk/services/sqs/SqsSendMessageApp.java create mode 100644 services/sqs/src/it/java/software/amazon/awssdk/services/sqs/SqsSendMessageOld.java create mode 100644 services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/BatchEntryIdGenerator.java create mode 100644 services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/FlushPolicy.java create mode 100644 services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/FlushScheduler.java create mode 100644 services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/RequestBatchStorage.java diff --git a/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/RequestBatchManagerSqsIntegrationTest.java b/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/RequestBatchManagerSqsIntegrationTest.java index 5faf289cda17..362dc4f6351e 100644 --- a/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/RequestBatchManagerSqsIntegrationTest.java +++ b/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/RequestBatchManagerSqsIntegrationTest.java @@ -91,6 +91,7 @@ public void setUp() { @AfterEach public void tearDown() { + purgeQueue(defaultQueueUrl); batchManager.close(); } diff --git a/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/SqsSendMessageApp.java b/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/SqsSendMessageApp.java new file mode 100644 index 000000000000..ea3e74f2d233 --- /dev/null +++ b/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/SqsSendMessageApp.java @@ -0,0 +1,69 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.sqs; + +import org.junit.Ignore; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.sqs.batchmanager.SqsAsyncBatchManager; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; + +/** + * Tests SQS message sending with improved memory management. + */ +@Ignore +public class SqsSendMessageApp { + private static final String QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/248213382692/myTestQueue0"; + private static final int MESSAGE_SIZE = 12_000; + private static final int MESSAGE_COUNT = 270000; + private static final int DELAY_MS = 30; + private static final int BATCH_SIZE = 1000; + + @Test + void testBatchSize() throws Exception { + // Create SQS client and batch manager + SqsAsyncClient sqsAsyncClient = SqsAsyncClient.builder().build(); + SqsAsyncBatchManager batchManager = sqsAsyncClient.batchManager(); + + // Create message template + String messageBody = createLargeString('a', MESSAGE_SIZE); + SendMessageRequest messageTemplate = SendMessageRequest.builder() + .queueUrl(QUEUE_URL) + .messageBody(messageBody) + .build(); + + + while (true) { + batchManager.sendMessage(messageTemplate).whenComplete((response, error) -> { + if (error != null) { + System.err.println("Error sending message: " + error.getMessage()); + } else { + System.out.println("Message sent successfully: " + response.messageId()); + } + }); + + Thread.sleep(DELAY_MS); + } + } + + /** + * Creates a string of specified length filled with the given character. + */ + private String createLargeString(char ch, int length) { + char[] chars = new char[length]; + java.util.Arrays.fill(chars, ch); + return new String(chars); + } +} diff --git a/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/SqsSendMessageOld.java b/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/SqsSendMessageOld.java new file mode 100644 index 000000000000..2d8168ed5e7e --- /dev/null +++ b/services/sqs/src/it/java/software/amazon/awssdk/services/sqs/SqsSendMessageOld.java @@ -0,0 +1,140 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.sqs; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.junit.Ignore; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.services.sqs.batchmanager.SqsAsyncBatchManager; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageResponse; + +/** + * Tests SQS message sending with size monitoring. + */ +@Ignore +public class SqsSendMessageOld { + + String QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/248213382692/myTestQueue0"; + private static final int MESSAGE_SIZE = 12_000; + private static final int MESSAGE_COUNT = 270000; + private static final int DELAY_MS = 30; + + @Test + void testBatchSize() throws Exception { + ExecutionInterceptor captureMessageSizeInterceptor = new CaptureMessageSizeInterceptor(); + + SqsAsyncClient sqsAsyncClient = SqsAsyncClient.builder() + // .overrideConfiguration(o -> o.addExecutionInterceptor(captureMessageSizeInterceptor)) + .build(); + + String messageBody = createLargeString('a', MESSAGE_SIZE); + SqsAsyncBatchManager sqsAsyncBatchManager = sqsAsyncClient.batchManager(); + + SendMessageRequest sendMessageRequest = SendMessageRequest.builder() + .queueUrl(QUEUE_URL) + .messageBody(messageBody) + .delaySeconds(20) + .build(); + + List> futures = sendMessages( + sqsAsyncBatchManager, sendMessageRequest, MESSAGE_COUNT); + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[futures.size()])).join(); + + System.out.println("All messages sent successfully"); + } + + /** + * Sends multiple messages with a delay between each. + * + * @param batchManager The batch manager to use + * @param messageRequest The message request template + * @param count Number of messages to send + * @return List of futures for the send operations + * @throws InterruptedException If thread is interrupted during sleep + */ + private List> sendMessages( + SqsAsyncBatchManager batchManager, + SendMessageRequest messageRequest, + int count) throws InterruptedException { + + List> futures = new ArrayList<>(); + + for (int i = 0; i < count; i++) { + CompletableFuture future = batchManager.sendMessage(messageRequest) + .whenComplete((response, error) -> { + if (error != null) { + error.printStackTrace(); + } else { + System.out.println("Message sent with ID: " + response.messageId()); + } + }); + + futures.add(future); + + if (i < count - 1) { + Thread.sleep(DELAY_MS); + } + } + + return futures; + } + + /** + * Creates a string of specified length filled with the given character. + * + * @param ch Character to fill the string with + * @param length Length of the string to create + * @return The generated string + */ + private String createLargeString(char ch, int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append(ch); + } + return sb.toString(); + } + + /** + * Interceptor that captures and logs message sizes in batch requests. + */ + static class CaptureMessageSizeInterceptor implements ExecutionInterceptor { + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + if (context.request() instanceof SendMessageBatchRequest) { + SendMessageBatchRequest batchRequest = (SendMessageBatchRequest) context.request(); + + System.out.println("Batch contains " + batchRequest.entries().size() + " messages"); + + int totalMessageBodySize = 0; + for (SendMessageBatchRequestEntry entry : batchRequest.entries()) { + int messageSize = entry.messageBody().length(); + totalMessageBodySize += messageSize; + System.out.println("Message body size: " + messageSize + " bytes"); + } + + System.out.println("Total message bodies size: " + totalMessageBodySize + " bytes"); + } + } + } +} diff --git a/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/BatchEntryIdGenerator.java b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/BatchEntryIdGenerator.java new file mode 100644 index 000000000000..27951c7f04e8 --- /dev/null +++ b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/BatchEntryIdGenerator.java @@ -0,0 +1,58 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.sqs.internal.batchmanager; + +import java.util.Map; +import java.util.concurrent.locks.ReentrantLock; +import software.amazon.awssdk.annotations.SdkInternalApi; + +/** + * Manages the generation of unique IDs for batch entries. + */ +@SdkInternalApi +class BatchEntryIdGenerator { + private int nextId = 0; + private int nextBatchEntry = 0; + private final ReentrantLock idLock = new ReentrantLock(); + + public String nextId() { + idLock.lock(); + try { + if (nextId == Integer.MAX_VALUE) { + nextId = 0; + } + return Integer.toString(nextId++); + } finally { + idLock.unlock(); + } + } + + public boolean hasNextBatchEntry(Map contextMap) { + return contextMap.containsKey(Integer.toString(nextBatchEntry)); + } + + public String nextBatchEntry() { + idLock.lock(); + try { + if (nextBatchEntry == Integer.MAX_VALUE) { + nextBatchEntry = 0; + } + return Integer.toString(nextBatchEntry++); + } finally { + idLock.unlock(); + } + } +} diff --git a/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/BatchingMap.java b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/BatchingMap.java index 171b76aa4bff..439934c936ca 100644 --- a/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/BatchingMap.java +++ b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/BatchingMap.java @@ -51,7 +51,12 @@ public void put(String batchKey, Supplier> scheduleFlush, Req if (batchContextMap.size() == maxBatchKeys) { throw new IllegalStateException("Reached MaxBatchKeys of: " + maxBatchKeys); } - return new RequestBatchBuffer<>(scheduleFlush.get(), maxBatchSize, maxBatchBytesSize, maxBufferSize); + return RequestBatchBuffer.builder() + .scheduledFlush(scheduleFlush.get()) + .maxBatchItems(maxBatchSize) + .maxBatchSizeInBytes(maxBatchBytesSize) + .maxBufferSize(maxBufferSize) + .build(); }).put(request, response); } @@ -68,17 +73,17 @@ public void forEach(BiConsumer> } public Map> flushableRequests(String batchKey) { - return batchContextMap.get(batchKey).flushableRequests(); + return batchContextMap.get(batchKey).extractBatchIfNeeded(); } public Map> flushableRequestsOnByteLimitBeforeAdd(String batchKey, RequestT request) { - return batchContextMap.get(batchKey).flushableRequestsOnByteLimitBeforeAdd(request); + return batchContextMap.get(batchKey).getFlushableBatchIfSizeExceeded(request); } public Map> flushableScheduledRequests(String batchKey, int maxBatchItems) { - return batchContextMap.get(batchKey).flushableScheduledRequests(maxBatchItems); + return batchContextMap.get(batchKey).extractEntriesForScheduledFlush(maxBatchItems); } public void cancelScheduledFlush(String batchKey) { @@ -86,10 +91,8 @@ public void cancelScheduledFlush(String batchKey) { } public void clear() { - for (Map.Entry> entry : batchContextMap.entrySet()) { - String key = entry.getKey(); - entry.getValue().clear(); - batchContextMap.remove(key); + for (RequestBatchBuffer buffer : batchContextMap.values()) { + buffer.clear(); } batchContextMap.clear(); } diff --git a/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/FlushPolicy.java b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/FlushPolicy.java new file mode 100644 index 000000000000..93ca6a98bb4f --- /dev/null +++ b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/FlushPolicy.java @@ -0,0 +1,76 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.sqs.internal.batchmanager; + + +import java.util.Map; +import software.amazon.awssdk.annotations.SdkInternalApi; + +/** + * Determines when a batch should be flushed based on various criteria. + */ +@SdkInternalApi +class FlushPolicy { + private final int maxBatchItems; + private final int maxBatchSizeInBytes; + + FlushPolicy(int maxBatchItems, int maxBatchSizeInBytes) { + this.maxBatchItems = maxBatchItems; + this.maxBatchSizeInBytes = maxBatchSizeInBytes; + } + + public int getMaxBatchItems() { + return maxBatchItems; + } + + // Updated method signature to use the same generic types + public boolean shouldFlush(Map> entries) { + return isBatchSizeLimitReached(entries) || isByteSizeThresholdCrossed(entries, 0); + } + + // Updated method signature to use the same generic types + public boolean shouldFlushBeforeAdd( + Map> entries, + RequestT incomingRequest) { + if (maxBatchSizeInBytes > 0 && !entries.isEmpty()) { + int incomingRequestBytes = RequestPayloadCalculator.calculateMessageSize(incomingRequest).orElse(0); + return isByteSizeThresholdCrossed(entries, incomingRequestBytes); + } + return false; + } + + // Updated method signature to use the same generic types + private boolean isBatchSizeLimitReached( + Map> entries) { + return entries.size() >= maxBatchItems; + } + + // Updated method signature to use the same generic types + private boolean isByteSizeThresholdCrossed( + Map> entries, + int additionalBytes) { + if (maxBatchSizeInBytes < 0) { + return false; + } + + int totalPayloadSize = entries.values().stream() + .map(BatchingExecutionContext::responsePayloadByteSize) + .mapToInt(opt -> opt.orElse(0)) + .sum() + additionalBytes; + + return totalPayloadSize > maxBatchSizeInBytes; + } +} diff --git a/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/FlushScheduler.java b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/FlushScheduler.java new file mode 100644 index 000000000000..f89aa149c267 --- /dev/null +++ b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/FlushScheduler.java @@ -0,0 +1,53 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.sqs.internal.batchmanager; + +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.locks.ReentrantLock; +import software.amazon.awssdk.annotations.SdkInternalApi; + +/** + * Manages the scheduled flush tasks. + */ +@SdkInternalApi +class FlushScheduler { + private ScheduledFuture scheduledFlush; + private final ReentrantLock schedulerLock = new ReentrantLock(); + + FlushScheduler(ScheduledFuture initialScheduledFlush) { + this.scheduledFlush = initialScheduledFlush; + } + + public void updateScheduledFlush(ScheduledFuture newScheduledFlush) { + schedulerLock.lock(); + try { + this.scheduledFlush = newScheduledFlush; + } finally { + schedulerLock.unlock(); + } + } + + public void cancelScheduledFlush() { + schedulerLock.lock(); + try { + if (scheduledFlush != null) { + scheduledFlush.cancel(false); + } + } finally { + schedulerLock.unlock(); + } + } +} diff --git a/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/RequestBatchBuffer.java b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/RequestBatchBuffer.java index d13b32c29e1e..3654fa4faa96 100644 --- a/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/RequestBatchBuffer.java +++ b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/RequestBatchBuffer.java @@ -18,148 +18,145 @@ import java.util.Collection; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ScheduledFuture; -import java.util.stream.Collectors; import software.amazon.awssdk.annotations.SdkInternalApi; +/** + * Main facade class that coordinates the batch buffer functionality. + */ @SdkInternalApi public final class RequestBatchBuffer { - private final Object flushLock = new Object(); - - private final Map> idToBatchContext; - private final int maxBatchItems; - private final int maxBufferSize; - private final int maxBatchSizeInBytes; - /** - * Batch entries in a batch request require a unique ID so nextId keeps track of the ID to assign to the next - * BatchingExecutionContext. For simplicity, the ID is just an integer that is incremented everytime a new request and - * response pair is received. - */ - private int nextId; - /** - * Keeps track of the ID of the next entry to be added in a batch request. This ID does not necessarily correlate to a request - * that already exists in the idToBatchContext map since it refers to the next entry (ex. if the last entry added to - * idToBatchContext had an id of 22, nextBatchEntry will have a value of 23). - */ - private int nextBatchEntry; - - /** - * The scheduled flush tasks associated with this batchBuffer. - */ - private ScheduledFuture scheduledFlush; - - public RequestBatchBuffer(ScheduledFuture scheduledFlush, - int maxBatchItems, int maxBatchSizeInBytes, int maxBufferSize) { - this.idToBatchContext = new ConcurrentHashMap<>(); - this.nextId = 0; - this.nextBatchEntry = 0; - this.scheduledFlush = scheduledFlush; - this.maxBatchItems = maxBatchItems; - this.maxBufferSize = maxBufferSize; - this.maxBatchSizeInBytes = maxBatchSizeInBytes; + private final RequestBatchStorage requestBatchStorage; + private final FlushPolicy flushPolicy; + private final FlushScheduler flushScheduler; + private final BatchEntryIdGenerator idGenerator; + + private RequestBatchBuffer(Builder builder) { + this.requestBatchStorage = new RequestBatchStorage<>(builder.maxBufferSize); + this.flushPolicy = new FlushPolicy<>(builder.maxBatchItems, builder.maxBatchSizeInBytes); + this.flushScheduler = new FlushScheduler(builder.scheduledFlush); + this.idGenerator = new BatchEntryIdGenerator(); } - public Map> flushableRequests() { - synchronized (flushLock) { - return (isByteSizeThresholdCrossed(0) || isMaxBatchSizeLimitReached()) - ? extractFlushedEntries(maxBatchItems) - : Collections.emptyMap(); - } + public static Builder builder() { + return new Builder<>(); } - - private boolean isMaxBatchSizeLimitReached() { - return idToBatchContext.size() >= maxBatchItems; - } - - public Map> flushableRequestsOnByteLimitBeforeAdd(RequestT request) { - synchronized (flushLock) { - if (maxBatchSizeInBytes > 0 && !idToBatchContext.isEmpty()) { - int incomingRequestBytes = RequestPayloadCalculator.calculateMessageSize(request).orElse(0); - if (isByteSizeThresholdCrossed(incomingRequestBytes)) { - return extractFlushedEntries(maxBatchItems); - } + /** + * Returns entries that should be flushed before adding a new request based on byte size constraints. + */ + public Map> getFlushableBatchIfSizeExceeded(RequestT request) { + requestBatchStorage.getLock().lock(); + try { + if (flushPolicy.shouldFlushBeforeAdd(requestBatchStorage.getAllEntries(), request)) { + return requestBatchStorage.extractEntries(flushPolicy.getMaxBatchItems(), idGenerator); } return Collections.emptyMap(); + } finally { + requestBatchStorage.getLock().unlock(); } } - private boolean isByteSizeThresholdCrossed(int incomingRequestBytes) { - if (maxBatchSizeInBytes < 0) { - return false; - } - int totalPayloadSize = idToBatchContext.values().stream() - .map(BatchingExecutionContext::responsePayloadByteSize) - .mapToInt(opt -> opt.orElse(0)) - .sum() + incomingRequestBytes; - return totalPayloadSize > maxBatchSizeInBytes; - } - - public Map> flushableScheduledRequests(int maxBatchItems) { - synchronized (flushLock) { - if (!idToBatchContext.isEmpty()) { - return extractFlushedEntries(maxBatchItems); + /** + * Returns entries that should be flushed due to scheduled flush. + */ + public Map> extractEntriesForScheduledFlush(int maxBatchItems) { + requestBatchStorage.getLock().lock(); + try { + if (!requestBatchStorage.isEmpty()) { + return requestBatchStorage.extractEntries(maxBatchItems, idGenerator); } return Collections.emptyMap(); + } finally { + requestBatchStorage.getLock().unlock(); } } - private Map> extractFlushedEntries(int maxBatchItems) { - LinkedHashMap> requestEntries = new LinkedHashMap<>(); - String nextEntry; - while (requestEntries.size() < maxBatchItems && hasNextBatchEntry()) { - nextEntry = nextBatchEntry(); - requestEntries.put(nextEntry, idToBatchContext.get(nextEntry)); - idToBatchContext.remove(nextEntry); - } - return requestEntries; - } - - public void put(RequestT request, CompletableFuture response) { - synchronized (this) { - if (idToBatchContext.size() == maxBufferSize) { - throw new IllegalStateException("Reached MaxBufferSize of: " + maxBufferSize); - } - - if (nextId == Integer.MAX_VALUE) { - nextId = 0; + /** + * Returns entries that should be flushed based on current buffer state. + */ + public Map> extractBatchIfNeeded() { + requestBatchStorage.getLock().lock(); + try { + if (flushPolicy.shouldFlush(requestBatchStorage.getAllEntries())) { + return requestBatchStorage.extractEntries(flushPolicy.getMaxBatchItems(), idGenerator); } - String id = Integer.toString(nextId++); - idToBatchContext.put(id, new BatchingExecutionContext<>(request, response)); + return Collections.emptyMap(); + } finally { + requestBatchStorage.getLock().unlock(); } } - private boolean hasNextBatchEntry() { - return idToBatchContext.containsKey(Integer.toString(nextBatchEntry)); - } - private String nextBatchEntry() { - if (nextBatchEntry == Integer.MAX_VALUE) { - nextBatchEntry = 0; - } - return Integer.toString(nextBatchEntry++); + /** + * Adds a request to the buffer. + */ + public void put(RequestT request, CompletableFuture response) { + String id = idGenerator.nextId(); + requestBatchStorage.put(id, new BatchingExecutionContext<>(request, response)); } + /** + * Updates the scheduled flush task. + */ public void putScheduledFlush(ScheduledFuture scheduledFlush) { - this.scheduledFlush = scheduledFlush; + flushScheduler.updateScheduledFlush(scheduledFlush); } + /** + * Cancels the scheduled flush task. + */ public void cancelScheduledFlush() { - scheduledFlush.cancel(false); + flushScheduler.cancelScheduledFlush(); } + /** + * Returns all response futures in the buffer. + */ public Collection> responses() { - return idToBatchContext.values() - .stream() - .map(BatchingExecutionContext::response) - .collect(Collectors.toList()); + return requestBatchStorage.getAllResponses(); } + /** + * Clears all entries from the buffer. + */ public void clear() { - idToBatchContext.clear(); + requestBatchStorage.clear(); + } + + /** + * Builder for RequestBatchBuffer. + */ + public static final class Builder { + private ScheduledFuture scheduledFlush; + private int maxBatchItems; + private int maxBatchSizeInBytes; + private int maxBufferSize; + + public Builder scheduledFlush(ScheduledFuture scheduledFlush) { + this.scheduledFlush = scheduledFlush; + return this; + } + + public Builder maxBatchItems(int maxBatchItems) { + this.maxBatchItems = maxBatchItems; + return this; + } + + public Builder maxBatchSizeInBytes(int maxBatchSizeInBytes) { + this.maxBatchSizeInBytes = maxBatchSizeInBytes; + return this; + } + + public Builder maxBufferSize(int maxBufferSize) { + this.maxBufferSize = maxBufferSize; + return this; + } + + public RequestBatchBuffer build() { + return new RequestBatchBuffer<>(this); + } } } diff --git a/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/RequestBatchStorage.java b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/RequestBatchStorage.java new file mode 100644 index 000000000000..8b95e3fcc890 --- /dev/null +++ b/services/sqs/src/main/java/software/amazon/awssdk/services/sqs/internal/batchmanager/RequestBatchStorage.java @@ -0,0 +1,104 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.sqs.internal.batchmanager; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Collectors; +import software.amazon.awssdk.annotations.SdkInternalApi; + +/** + * Responsible for storing batch entries and providing operations to access and extract them. + */ +@SdkInternalApi +class RequestBatchStorage { + private final Map> idToBatchContext; + private final int maxBufferSize; + private final ReentrantLock lock = new ReentrantLock(); + + RequestBatchStorage(int maxBufferSize) { + this.idToBatchContext = new ConcurrentHashMap<>(); + this.maxBufferSize = maxBufferSize; + } + + public ReentrantLock getLock() { + return lock; + } + + public void put(String id, BatchingExecutionContext context) { + lock.lock(); + try { + if (idToBatchContext.size() == maxBufferSize) { + throw new IllegalStateException("Reached MaxBufferSize of: " + maxBufferSize); + } + idToBatchContext.put(id, context); + } finally { + lock.unlock(); + } + } + + public Map> getAllEntries() { + // No need for locking here as we're returning an unmodifiable view + return Collections.unmodifiableMap(idToBatchContext); + } + + public boolean isEmpty() { + // ConcurrentHashMap's isEmpty is thread-safe + return idToBatchContext.isEmpty(); + } + + public Map> extractEntries(int maxEntries, + BatchEntryIdGenerator idGenerator) { + Map> extractedEntries = + new ConcurrentHashMap<>(Math.min(maxEntries, idToBatchContext.size())); + + String nextEntry; + int count = 0; + + while (count < maxEntries && idGenerator.hasNextBatchEntry(idToBatchContext)) { + nextEntry = idGenerator.nextBatchEntry(); + BatchingExecutionContext context = idToBatchContext.get(nextEntry); + if (context != null) { + extractedEntries.put(nextEntry, context); + idToBatchContext.remove(nextEntry); + count++; + } + } + + return extractedEntries; + } + + public Collection> getAllResponses() { + // Using ConcurrentHashMap's thread-safe iteration + return idToBatchContext.values() + .stream() + .map(BatchingExecutionContext::response) + .collect(Collectors.toList()); + } + + public void clear() { + lock.lock(); + try { + idToBatchContext.clear(); + } finally { + lock.unlock(); + } + } +} diff --git a/services/sqs/src/test/java/software/amazon/awssdk/services/sqs/batchmanager/RequestBatchBufferTest.java b/services/sqs/src/test/java/software/amazon/awssdk/services/sqs/batchmanager/RequestBatchBufferTest.java index 0829d8cd5693..0a6d86191957 100644 --- a/services/sqs/src/test/java/software/amazon/awssdk/services/sqs/batchmanager/RequestBatchBufferTest.java +++ b/services/sqs/src/test/java/software/amazon/awssdk/services/sqs/batchmanager/RequestBatchBufferTest.java @@ -45,35 +45,55 @@ void setUp() { @Test void whenPutRequestThenBufferContainsRequest() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(10) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); CompletableFuture response = new CompletableFuture<>(); batchBuffer.put("request1", response); assertEquals(1, batchBuffer.responses().size()); } @Test - void whenFlushableRequestsThenReturnRequestsUpToMaxBatchItems() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 1, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + void whenExtractBatchIfNeededUpToMaxBatchItems() { + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(1) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); CompletableFuture response = new CompletableFuture<>(); batchBuffer.put("request1", response); - Map> flushedRequests = batchBuffer.flushableRequests(); + Map> flushedRequests = batchBuffer.extractBatchIfNeeded(); assertEquals(1, flushedRequests.size()); assertTrue(flushedRequests.containsKey("0")); } @Test - void whenFlushableScheduledRequestsThenReturnAllRequests() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + void whenExtractEntriesForScheduledRequestsThenReturnAllFlush() { + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(10) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); CompletableFuture response = new CompletableFuture<>(); batchBuffer.put("request1", response); - Map> flushedRequests = batchBuffer.flushableScheduledRequests(1); + Map> flushedRequests = batchBuffer.extractEntriesForScheduledFlush(1); assertEquals(1, flushedRequests.size()); assertTrue(flushedRequests.containsKey("0")); } @Test void whenMaxBufferSizeReachedThenThrowException() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 3, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, 10); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(3) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(10) + .build(); for (int i = 0; i < 10; i++) { batchBuffer.put("request" + i, new CompletableFuture<>()); } @@ -82,7 +102,12 @@ void whenMaxBufferSizeReachedThenThrowException() { @Test void whenPutScheduledFlushThenFlushIsSet() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(10) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); ScheduledFuture newScheduledFlush = mock(ScheduledFuture.class); batchBuffer.putScheduledFlush(newScheduledFlush); assertNotNull(newScheduledFlush); @@ -90,14 +115,24 @@ void whenPutScheduledFlushThenFlushIsSet() { @Test void whenCancelScheduledFlushThenFlushIsCancelled() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(10) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); batchBuffer.cancelScheduledFlush(); verify(scheduledFlush).cancel(false); } @Test void whenGetResponsesThenReturnAllResponses() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(10) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); CompletableFuture response1 = new CompletableFuture<>(); CompletableFuture response2 = new CompletableFuture<>(); batchBuffer.put("request1", response1); @@ -110,7 +145,12 @@ void whenGetResponsesThenReturnAllResponses() { @Test void whenClearBufferThenBufferIsEmpty() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(10) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); CompletableFuture response = new CompletableFuture<>(); batchBuffer.put("request1", response); batchBuffer.clear(); @@ -119,61 +159,88 @@ void whenClearBufferThenBufferIsEmpty() { @Test void whenExtractFlushedEntriesThenReturnCorrectEntries() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 5, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(5) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); for (int i = 0; i < 5; i++) { batchBuffer.put("request" + i, new CompletableFuture<>()); } - Map> flushedEntries = batchBuffer.flushableRequests(); + Map> flushedEntries = batchBuffer.extractBatchIfNeeded(); assertEquals(5, flushedEntries.size()); } @Test void whenHasNextBatchEntryThenReturnTrue() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 1, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(1) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); batchBuffer.put("request1", new CompletableFuture<>()); - assertTrue(batchBuffer.flushableRequests().containsKey("0")); + assertTrue(batchBuffer.extractBatchIfNeeded().containsKey("0")); } - @Test void whenNextBatchEntryThenReturnNextEntryId() { - batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 1, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + batchBuffer = RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(1) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); batchBuffer.put("request1", new CompletableFuture<>()); - assertEquals("0", batchBuffer.flushableRequests().keySet().iterator().next()); + assertEquals("0", batchBuffer.extractBatchIfNeeded().keySet().iterator().next()); } @Test void whenRequestPassedWithLessBytesinArgs_thenCheckForSizeOnly_andDonotFlush() { - RequestBatchBuffer batchBuffer - = new RequestBatchBuffer<>(scheduledFlush, 5, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + RequestBatchBuffer batchBuffer = + RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(5) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); for (int i = 0; i < 5; i++) { batchBuffer.put(SendMessageRequest.builder().build(), new CompletableFuture<>()); } Map> flushedEntries = - batchBuffer.flushableRequestsOnByteLimitBeforeAdd(SendMessageRequest.builder().messageBody("Hi").build()); + batchBuffer.getFlushableBatchIfSizeExceeded(SendMessageRequest.builder().messageBody("Hi").build()); assertEquals(0, flushedEntries.size()); } - - @Test void testFlushWhenPayloadExceedsMaxSize() { - RequestBatchBuffer batchBuffer - = new RequestBatchBuffer<>(scheduledFlush, 5, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + RequestBatchBuffer batchBuffer = + RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(5) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); String largeMessageBody = createLargeString('a',245_760); batchBuffer.put(SendMessageRequest.builder().messageBody(largeMessageBody).build(), new CompletableFuture<>()); Map> flushedEntries = - batchBuffer.flushableRequestsOnByteLimitBeforeAdd(SendMessageRequest.builder().messageBody("NewMessage").build()); + batchBuffer.getFlushableBatchIfSizeExceeded(SendMessageRequest.builder().messageBody("NewMessage").build()); assertEquals(1, flushedEntries.size()); } @Test void testFlushWhenCumulativePayloadExceedsMaxSize() { - RequestBatchBuffer batchBuffer - = new RequestBatchBuffer<>(scheduledFlush, 5, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize); + RequestBatchBuffer batchBuffer = + RequestBatchBuffer.builder() + .scheduledFlush(scheduledFlush) + .maxBatchItems(5) + .maxBatchSizeInBytes(MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES) + .maxBufferSize(maxBufferSize) + .build(); String largeMessageBody = createLargeString('a',130_000); batchBuffer.put(SendMessageRequest.builder().messageBody(largeMessageBody).build(), @@ -181,13 +248,12 @@ void testFlushWhenCumulativePayloadExceedsMaxSize() { batchBuffer.put(SendMessageRequest.builder().messageBody(largeMessageBody).build(), new CompletableFuture<>()); Map> flushedEntries = - batchBuffer.flushableRequestsOnByteLimitBeforeAdd(SendMessageRequest.builder().messageBody("NewMessage").build()); + batchBuffer.getFlushableBatchIfSizeExceeded(SendMessageRequest.builder().messageBody("NewMessage").build()); //Flushes both the messages since thier sum is greater than 256Kb assertEquals(2, flushedEntries.size()); } - private String createLargeString(char ch, int length) { StringBuilder sb = new StringBuilder(length); for (int i = 0; i < length; i++) { @@ -195,7 +261,4 @@ private String createLargeString(char ch, int length) { } return sb.toString(); } - - - -} \ No newline at end of file +}