Skip to content

Commit fb5c1e1

Browse files
idelpivnitskiylawrencewang49
authored andcommitted
Streaming gRPC client can leak responses if malformed response received (#3354)
Motivation: When gRPC client receives a response, it first validates that it's formed according to gRPC specification. If any headers or trailers are missing, `validateResponseAndGetPayload` will throw an exception. In this case, we leak undrained response payload body. Modifications: 1. Try catch logic of `validateResponseAndGetPayload`, subscribe and cancel response message body in case of unexpected exceptions. 2. Enhance `ProtocolCompatibilityTest` to validate we never leak responses across all tests. Result: Responses are properly drained even when we receive malformed responses.
1 parent a3ebf2b commit fb5c1e1

File tree

2 files changed

+83
-21
lines changed

2 files changed

+83
-21
lines changed

servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/DefaultGrpcClientCallFactory.java

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io.servicetalk.concurrent.api.AsyncContext;
2020
import io.servicetalk.concurrent.api.Completable;
2121
import io.servicetalk.concurrent.api.Publisher;
22+
import io.servicetalk.concurrent.internal.CancelImmediatelySubscriber;
2223
import io.servicetalk.encoding.api.BufferDecoder;
2324
import io.servicetalk.encoding.api.BufferDecoderGroup;
2425
import io.servicetalk.encoding.api.BufferEncoder;
@@ -41,6 +42,7 @@
4142
import java.util.concurrent.TimeUnit;
4243
import javax.annotation.Nullable;
4344

45+
import static io.servicetalk.concurrent.api.SourceAdapters.toSource;
4446
import static io.servicetalk.concurrent.internal.BlockingIterables.singletonBlockingIterable;
4547
import static io.servicetalk.encoding.api.Identity.identityEncoder;
4648
import static io.servicetalk.grpc.api.GrpcHeaderValues.GRPC_CONTENT_TYPE_PROTO_SUFFIX;
@@ -154,11 +156,17 @@ public <Req, Resp> StreamingClientCall<Req, Resp> newStreamingCall(
154156
streamingHttpClient.executionContext().bufferAllocator()));
155157
return streamingHttpClient.request(httpRequest)
156158
.flatMapPublisher(response -> {
157-
extractResponseContext(response, metadata);
158-
return validateResponseAndGetPayload(response, responseContentType,
159-
streamingHttpClient.executionContext().bufferAllocator(),
160-
readGrpcMessageEncodingRaw(response.headers(), deserializerIdentity, deserializers,
161-
GrpcStreamingDeserializer::messageEncoding), httpRequest.requestTarget());
159+
try {
160+
161+
extractResponseContext(response, metadata);
162+
return validateResponseAndGetPayload(response, responseContentType,
163+
streamingHttpClient.executionContext().bufferAllocator(),
164+
readGrpcMessageEncodingRaw(response.headers(), deserializerIdentity, deserializers,
165+
GrpcStreamingDeserializer::messageEncoding), httpRequest.requestTarget());
166+
} catch (Throwable t) {
167+
toSource(response.messageBody()).subscribe(CancelImmediatelySubscriber.INSTANCE);
168+
return Publisher.failed(GrpcStatusException.fromThrowable(t));
169+
}
162170
})
163171
.onErrorMap(GrpcStatusException::fromThrowable);
164172
};
@@ -291,11 +299,16 @@ public <Req, Resp> BlockingStreamingClientCall<Req, Resp> newBlockingStreamingCa
291299
streamingHttpClient.executionContext().bufferAllocator()));
292300
try {
293301
final BlockingStreamingHttpResponse response = client.request(httpRequest);
294-
extractResponseContext(response, metadata);
295-
return validateResponseAndGetPayload(response.toStreamingResponse(), responseContentType,
296-
client.executionContext().bufferAllocator(), readGrpcMessageEncodingRaw(
297-
response.headers(), deserializerIdentity, deserializers,
298-
GrpcStreamingDeserializer::messageEncoding), httpRequest.requestTarget()).toIterable();
302+
try {
303+
extractResponseContext(response, metadata);
304+
return validateResponseAndGetPayload(response.toStreamingResponse(), responseContentType,
305+
client.executionContext().bufferAllocator(), readGrpcMessageEncodingRaw(response.headers(),
306+
deserializerIdentity, deserializers, GrpcStreamingDeserializer::messageEncoding),
307+
httpRequest.requestTarget()).toIterable();
308+
} catch (Throwable t) {
309+
response.messageBody().iterator().close();
310+
throw GrpcStatusException.fromThrowable(t);
311+
}
299312
} catch (Throwable cause) {
300313
throw GrpcStatusException.fromThrowable(cause);
301314
}

servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ProtocolCompatibilityTest.java

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,27 @@
5353
import io.servicetalk.grpc.netty.CompatProto.Compat.ServiceFactory;
5454
import io.servicetalk.grpc.netty.CompatProto.RequestContainer.CompatRequest;
5555
import io.servicetalk.grpc.netty.CompatProto.ResponseContainer.CompatResponse;
56+
import io.servicetalk.http.api.FilterableStreamingHttpClient;
57+
import io.servicetalk.http.api.HttpExecutionStrategies;
58+
import io.servicetalk.http.api.HttpExecutionStrategy;
5659
import io.servicetalk.http.api.HttpResponse;
5760
import io.servicetalk.http.api.HttpResponseStatus;
5861
import io.servicetalk.http.api.HttpServerBuilder;
5962
import io.servicetalk.http.api.HttpServerContext;
6063
import io.servicetalk.http.api.HttpServiceContext;
6164
import io.servicetalk.http.api.SingleAddressHttpClientBuilder;
6265
import io.servicetalk.http.api.StreamingHttpClient;
66+
import io.servicetalk.http.api.StreamingHttpClientFilter;
67+
import io.servicetalk.http.api.StreamingHttpClientFilterFactory;
6368
import io.servicetalk.http.api.StreamingHttpRequest;
69+
import io.servicetalk.http.api.StreamingHttpRequester;
6470
import io.servicetalk.http.api.StreamingHttpResponse;
6571
import io.servicetalk.http.api.StreamingHttpResponseFactory;
6672
import io.servicetalk.http.api.StreamingHttpService;
6773
import io.servicetalk.http.api.StreamingHttpServiceFilter;
6874
import io.servicetalk.http.netty.HttpClients;
6975
import io.servicetalk.http.netty.HttpServers;
76+
import io.servicetalk.http.utils.BeforeFinallyHttpOperator;
7077
import io.servicetalk.test.resources.DefaultTestCerts;
7178
import io.servicetalk.transport.api.ClientSslConfigBuilder;
7279
import io.servicetalk.transport.api.ServerContext;
@@ -91,6 +98,7 @@
9198
import io.grpc.stub.StreamObserver;
9299
import io.netty.handler.ssl.SslContext;
93100
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
101+
import org.junit.jupiter.api.AfterEach;
94102
import org.junit.jupiter.api.function.ThrowingSupplier;
95103
import org.junit.jupiter.params.ParameterizedTest;
96104
import org.junit.jupiter.params.provider.Arguments;
@@ -200,6 +208,15 @@ public SocketAddress listenAddress() {
200208
private static final String CUSTOM_ERROR_MESSAGE = "custom error message";
201209
private static final DeliberateException SERVER_PROCESSED_TOKEN = new DeliberateException();
202210
private static final Duration DEFAULT_DEADLINE = ofMillis(100);
211+
private static final boolean[] TRUE_FALSE = {true, false};
212+
private static final String[] COMPRESSION = {"gzip", "identity", null};
213+
214+
private final ResponseLeakValidator responseLeakValidator = new ResponseLeakValidator();
215+
216+
@AfterEach
217+
void finalChecks() {
218+
responseLeakValidator.assertNoPendingRequests();
219+
}
203220

204221
private enum ErrorMode {
205222
NONE,
@@ -211,9 +228,6 @@ private enum ErrorMode {
211228
STATUS_IN_RESPONSE
212229
}
213230

214-
private static final boolean[] TRUE_FALSE = {true, false};
215-
private static final String[] COMPRESSION = {"gzip", "identity", null};
216-
217231
private static Collection<Arguments> sslStreamingAndCompressionParams() {
218232
List<Arguments> args = new ArrayList<>();
219233
for (boolean ssl : TRUE_FALSE) {
@@ -1217,16 +1231,21 @@ private static CompatResponse computeResponse(final int value) {
12171231
.build();
12181232
}
12191233

1220-
private static CompatClient serviceTalkClient(final SocketAddress serverAddress, final boolean ssl,
1221-
@Nullable final String compression,
1222-
@Nullable final Duration timeout) {
1234+
private CompatClient serviceTalkClient(final SocketAddress serverAddress, final boolean ssl,
1235+
@Nullable final String compression,
1236+
@Nullable final Duration timeout) {
12231237
final GrpcClientBuilder<InetSocketAddress, InetSocketAddress> builder =
12241238
GrpcClients.forResolvedAddress((InetSocketAddress) serverAddress);
1225-
if (ssl) {
1226-
builder.initializeHttp(b -> b.sslConfig(new ClientSslConfigBuilder(
1227-
DefaultTestCerts::loadServerCAPem).peerHost(serverPemHostname()).build()));
1228-
}
1229-
if (null != timeout) {
1239+
1240+
builder.initializeHttp(b -> {
1241+
if (ssl) {
1242+
b.sslConfig(new ClientSslConfigBuilder(
1243+
DefaultTestCerts::loadServerCAPem).peerHost(serverPemHostname()).build());
1244+
}
1245+
b.appendClientFilter(responseLeakValidator);
1246+
});
1247+
1248+
if (timeout != null) {
12301249
builder.defaultTimeout(timeout);
12311250
}
12321251
return builder.build(new Compat.ClientFactory()
@@ -1854,4 +1873,34 @@ private CompatResponse response(final int value) throws Exception {
18541873
return computeResponse(value);
18551874
}
18561875
}
1876+
1877+
private static final class ResponseLeakValidator implements StreamingHttpClientFilterFactory {
1878+
1879+
private final AtomicInteger pendingRequests = new AtomicInteger();
1880+
1881+
@Override
1882+
public StreamingHttpClientFilter create(FilterableStreamingHttpClient client) {
1883+
return new StreamingHttpClientFilter(client) {
1884+
@Override
1885+
protected Single<StreamingHttpResponse> request(StreamingHttpRequester delegate,
1886+
StreamingHttpRequest request) {
1887+
return Single.defer(() -> {
1888+
pendingRequests.incrementAndGet();
1889+
return delegate.request(request)
1890+
.liftSync(new BeforeFinallyHttpOperator(pendingRequests::decrementAndGet))
1891+
.shareContextOnSubscribe();
1892+
});
1893+
}
1894+
};
1895+
}
1896+
1897+
@Override
1898+
public HttpExecutionStrategy requiredOffloads() {
1899+
return HttpExecutionStrategies.offloadNone();
1900+
}
1901+
1902+
void assertNoPendingRequests() {
1903+
assertThat("Detected pending requests, possible response leak", pendingRequests.get(), is(0));
1904+
}
1905+
}
18571906
}

0 commit comments

Comments
 (0)