diff --git a/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/spi/v1/MessageDispatcher.java b/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/spi/v1/MessageDispatcher.java index b6b7318e30c7..f349118bcbb8 100644 --- a/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/spi/v1/MessageDispatcher.java +++ b/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/spi/v1/MessageDispatcher.java @@ -16,8 +16,8 @@ package com.google.cloud.pubsub.spi.v1; -import com.google.api.gax.core.FlowController; import com.google.api.gax.core.ApiClock; +import com.google.api.gax.core.FlowController; import com.google.api.stats.Distribution; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; @@ -28,6 +28,7 @@ import com.google.pubsub.v1.PubsubMessage; import com.google.pubsub.v1.ReceivedMessage; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; @@ -260,23 +261,47 @@ public int getMessageDeadlineSeconds() { } public void processReceivedMessages(List responseMessages) { - int receivedMessagesCount = responseMessages.size(); - if (receivedMessagesCount == 0) { + if (responseMessages.isEmpty()) { return; } - Instant now = new Instant(clock.millisTime()); - int totalByteCount = 0; + final ArrayList ackHandlers = new ArrayList<>(responseMessages.size()); for (ReceivedMessage pubsubMessage : responseMessages) { - int messageSize = pubsubMessage.getMessage().getSerializedSize(); - totalByteCount += messageSize; - ackHandlers.add(new AckHandler(pubsubMessage.getAckId(), messageSize)); + int size = pubsubMessage.getMessage().getSerializedSize(); + AckHandler handler = new AckHandler(pubsubMessage.getAckId(), size); + ackHandlers.add(handler); } + + Instant now = new Instant(clock.millisTime()); Instant expiration = now.plus(messageDeadlineSeconds * 1000); logger.log( Level.FINER, "Received {0} messages at {1}", new Object[] {responseMessages.size(), now}); + // We must add the ackHandlers to outstandingAckHandlers before setting up the deadline extension alarm. + // Otherwise, the alarm might go off before we can add the handlers. + synchronized (outstandingAckHandlers) { + // AckDeadlineAlarm modifies lists in outstandingAckHandlers in-place and might run at any time. + // We will also later iterate over ackHandlers when we give messages to user code. + // We must create a new list to pass to outstandingAckHandlers, + // so that we can't iterate and modify the list concurrently. + ArrayList ackHandlersCopy = new ArrayList<>(ackHandlers); + outstandingAckHandlers.add( + new ExtensionJob(expiration, INITIAL_ACK_DEADLINE_EXTENSION_SECONDS, ackHandlersCopy)); + } + + // Deadline extension must be set up before we reserve flow control. + // Flow control might block for a while, and extension will keep messages from expiring. + setupNextAckDeadlineExtensionAlarm(expiration); + + // Reserving flow control must happen before we give the messages to the user, + // otherwise the user code might be given too many messages to process at once. + try { + flowController.reserve(responseMessages.size(), getTotalMessageSize(responseMessages)); + } catch (FlowController.FlowControlException e) { + throw new IllegalStateException("Flow control unexpected exception", e); + } messagesWaiter.incrementPendingMessages(responseMessages.size()); + Iterator acksIterator = ackHandlers.iterator(); for (ReceivedMessage userMessage : responseMessages) { final PubsubMessage message = userMessage.getMessage(); @@ -302,18 +327,14 @@ public void run() { } }); } + } - synchronized (outstandingAckHandlers) { - outstandingAckHandlers.add( - new ExtensionJob(expiration, INITIAL_ACK_DEADLINE_EXTENSION_SECONDS, ackHandlers)); - } - setupNextAckDeadlineExtensionAlarm(expiration); - - try { - flowController.reserve(receivedMessagesCount, totalByteCount); - } catch (FlowController.FlowControlException unexpectedException) { - throw new IllegalStateException("Flow control unexpected exception", unexpectedException); + private static int getTotalMessageSize(Collection messages) { + int total = 0; + for (ReceivedMessage message : messages) { + total += message.getMessage().getSerializedSize(); } + return total; } private void setupPendingAcksAlarm() {