diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 7a3d96ceaef0..a5c5d0317eef 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -199,7 +199,7 @@ public void handle(ResponseMessage message) throws Exception { } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); - if (listener == null) { + if (listener == null && resp.requestId != RpcFailure.EMPTY_REQUEST_ID) { logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", resp.requestId, getRemoteAddress(channel), resp.errorString); } else { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index a76624ef5dc9..65d913e75f5e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -22,6 +22,8 @@ /** Response to {@link RpcRequest} for a failed RPC. */ public final class RpcFailure extends AbstractMessage implements ResponseMessage { + public static final long EMPTY_REQUEST_ID = Long.MIN_VALUE; + public final long requestId; public final String errorString; diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e94453578e6b..b9ac89c5a14d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.io.InvalidClassException; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -206,6 +207,11 @@ public void onFailure(Throwable e) { private void processOneWayMessage(OneWayMessage req) { try { rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); + } catch (InvalidClassException ice) { + final String msg = "There is probably a version mismatch between client and server: "; + respond(new RpcFailure(RpcFailure.EMPTY_REQUEST_ID, msg + + Throwables.getStackTraceAsString(ice))); + logger.error(msg, ice); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); } finally { diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 2656cbee95a2..14cea9e6083c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.InvalidClassException; import java.util.ArrayList; import java.util.List; @@ -27,6 +28,8 @@ import io.netty.util.concurrent.GenericFutureListener; import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.*; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -100,6 +103,41 @@ public void handleFetchRequestAndStreamRequest() throws Exception { assert responseAndPromisePairs.size() == 3; } + @Test + public void handleOneWayMessageWithWrongSerialVersionUID() throws Exception { + RpcHandler rpcHandler = new NoOpRpcHandler(); + Channel channel = mock(Channel.class); + List> responseAndPromisePairs = + new ArrayList<>(); + + when(channel.writeAndFlush(any())) + .thenAnswer(invocationOnMock -> { + Object response = invocationOnMock.getArguments()[0]; + ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); + responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); + return channelFuture; + }); + + TransportClient reverseClient = mock(TransportClient.class); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, + rpcHandler, 2L); + + // req.body().nioByteBuffer() is the method that throws the InvalidClassException + // with wrong svUID, so let's mock it + ManagedBuffer body = mock(ManagedBuffer.class); + when(body.nioByteBuffer()).thenThrow(new InvalidClassException("test - wrong version")); + RequestMessage msg = new OneWayMessage(body); + + requestHandler.handle(msg); + + assertEquals(responseAndPromisePairs.size(), 1); + assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof RpcFailure); + assertEquals(((RpcFailure) responseAndPromisePairs.get(0).getLeft()).requestId, + RpcFailure.EMPTY_REQUEST_ID); + + responseAndPromisePairs.get(0).getRight().finish(true); + } + private class ExtendedChannelPromise extends DefaultChannelPromise { private List>> listeners = new ArrayList<>();