Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AmazonS3-b003027.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Consider outstanding demand in ByteBufferStoringSubscriber before requesting more - fixes OutOfMemoryIssues in S3CrtRequestBodyStreamAdapter"
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
*/
@SdkInternalApi
public final class S3CrtRequestBodyStreamAdapter implements HttpRequestBodyStream {
private static final long MINIMUM_BYTES_BUFFERED = 1024 * 1024L;
private static final long MINIMUM_BYTES_BUFFERED = 16 * 1024 * 1024L;
private final SdkHttpContentPublisher bodyPublisher;
private final ByteBufferStoringSubscriber requestBodySubscriber;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.http.async.SdkHttpContentPublisher;
import software.amazon.awssdk.utils.async.ByteBufferStoringSubscriber;

class S3CrtRequestBodyStreamAdapterTest {

Expand Down Expand Up @@ -56,6 +58,29 @@ void getRequestData_fillsInputBuffer_publisherBuffersAreSmaller() {
assertThat(inputBuffer.remaining()).isEqualTo(0);
}

@Test
void getRequestData_fillsInputBuffer_limitsOutstandingDemand() {
int minBytesBuffered = 16 * 1024 * 1024;
int inputBufferSize = 1024;

RequestTrackingPublisher requestTrackingPublisher = new RequestTrackingPublisher();
SdkHttpContentPublisher requestBody = requestBody(requestTrackingPublisher, minBytesBuffered);

S3CrtRequestBodyStreamAdapter adapter = new S3CrtRequestBodyStreamAdapter(requestBody);

ByteBuffer inputBuffer = ByteBuffer.allocate(inputBufferSize);
adapter.sendRequestBody(inputBuffer); // initiate the subscription, but no bytes available, makes 1 request

// release 1 request of minBytesBuffered bytes of data, calling onNext (satisfies one request, but then requests 1 more)
requestTrackingPublisher.release(1, minBytesBuffered-100);
assertThat(requestTrackingPublisher.requests()).isEqualTo(2);

// call sendRequestBody, outstandingDemand=1, sizeHint=16*1024*1024-100 + existing data buffered is > our min
// so no more requests will be made
adapter.sendRequestBody(inputBuffer);
assertThat(requestTrackingPublisher.requests()).isEqualTo(2);
}

private static SdkHttpContentPublisher requestBody(Publisher<ByteBuffer> delegate, long size) {
return new SdkHttpContentPublisher() {
@Override
Expand Down Expand Up @@ -114,4 +139,44 @@ public void getRequestData_publisherThrows_wrapsExceptionIfNotRuntimeException()
.isInstanceOf(RuntimeException.class)
.hasCauseInstanceOf(IOException.class);
}

private static class RequestTrackingPublisher implements Publisher<ByteBuffer> {
ByteBufferStoringSubscriber subscriber;
RequestTrackingSubscription subscription = new RequestTrackingSubscription();

@Override
public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
assertThat(subscriber).isInstanceOf(ByteBufferStoringSubscriber.class);
this.subscriber = (ByteBufferStoringSubscriber) subscriber;
this.subscriber.onSubscribe(subscription);
}

// publish up to n requests
public void release(int n, int size) {
for (int i = 0; i < n; i++) {
ByteBuffer buffer = ByteBuffer.allocate(size);
subscriber.onNext(buffer);
}
}

public long requests() {
return subscription.requests;
}
}

private static class RequestTrackingSubscription implements Subscription {

long requests = 0;

@Override
public void request(long n) {
requests += n;
}

@Override
public void cancel() {

}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
Expand Down Expand Up @@ -56,13 +57,19 @@ public class ByteBufferStoringSubscriber implements Subscriber<ByteBuffer> {

private final Phaser phaser = new Phaser(1);

private final AtomicInteger outstandingDemand = new AtomicInteger(0);

private volatile long byteBufferSizeHint = 0L;

/**
* The active subscription. Set when {@link #onSubscribe(Subscription)} is invoked.
*/
private Subscription subscription;

/**
* Create a subscriber that stores at least {@code minimumBytesBuffered} in memory for retrieval.
* Create a subscriber that stores at least {@code minimumBytesBuffered} in memory for retrieval. The subscriber will
* only request more from the subscription when fewer bytes are buffered AND in flight requests from the subscription will
* likely be under minimumBytesBuffered.
*/
public ByteBufferStoringSubscriber(long minimumBytesBuffered) {
this.minimumBytesBuffered = Validate.isPositive(minimumBytesBuffered, "Data buffer minimum must be positive");
Expand Down Expand Up @@ -174,13 +181,19 @@ private int transfer(ByteBuffer in, ByteBuffer out) {
public void onSubscribe(Subscription s) {
storingSubscriber.onSubscribe(new DemandIgnoringSubscription(s));
subscription = s;
outstandingDemand.incrementAndGet();
subscription.request(1);
subscriptionLatch.countDown();
}

@Override
public void onNext(ByteBuffer byteBuffer) {
int remaining = byteBuffer.remaining();
outstandingDemand.decrementAndGet();
// atomic update not required here, in a race it does not matter which thread sets this value since it is not being
// incremented, just set.
byteBufferSizeHint = byteBuffer.remaining();

storingSubscriber.onNext(byteBuffer.duplicate());
addBufferedDataAmount(remaining);
phaser.arrive();
Expand All @@ -204,7 +217,9 @@ private void addBufferedDataAmount(long amountToAdd) {
}

private void maybeRequestMore(long currentDataBuffered) {
if (currentDataBuffered < minimumBytesBuffered) {
long dataBufferedAndInFlight = currentDataBuffered + (byteBufferSizeHint * outstandingDemand.get());
if (dataBufferedAndInFlight < minimumBytesBuffered) {
outstandingDemand.incrementAndGet();
subscription.request(1);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ public void doesNotRequestMoreThanMaxBytes() {
verifyNoMoreInteractions(subscription);
}

@Test
public void doesNotRequestMoreWhenInflightMoreThanMinBytes() {
ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(5);

subscriber.onSubscribe(subscription); // request 1, demand = 1
subscriber.onNext(fullByteBufferOfSize(3)); // demand = 0, sizeHint=3
subscriber.transferTo(emptyByteBufferOfSize(1)); // requests more, demand = 1
subscriber.transferTo(emptyByteBufferOfSize(1)); // requests more, demand = 2
verify(subscription, times(3)).request(1);

//sizeHint=3, demand=2, dataBufferedAndInFlight=6. 6 > 5, so no new request
subscriber.transferTo(emptyByteBufferOfSize(1));
verifyNoMoreInteractions(subscription);
}

@Test
public void canStoreMoreThanMaxBytesButWontAskForMoreUntilBelowMax() {
ByteBufferStoringSubscriber subscriber = new ByteBufferStoringSubscriber(3);
Expand Down
Loading