From 1c8847494c29d4b51182ecfeebb5cc85e000e7a1 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 7 Feb 2017 14:30:42 -0800 Subject: [PATCH 1/2] Avoid using ExecutorClassLoader to load Netty generated classes --- .../spark/network/TransportContext.java | 22 ++++++++++++++----- .../network/protocol/MessageDecoder.java | 6 +++++ .../network/protocol/MessageEncoder.java | 6 +++++ .../server/TransportChannelHandler.java | 11 +++++----- .../apache/spark/network/ProtocolSuite.java | 8 +++---- .../scala/org/apache/spark/util/Utils.scala | 18 +++++---------- 6 files changed, 43 insertions(+), 28 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 5b69e2bb0354..37ba543380f0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -62,8 +62,20 @@ public class TransportContext { private final RpcHandler rpcHandler; private final boolean closeIdleConnections; - private final MessageEncoder encoder; - private final MessageDecoder decoder; + /** + * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created + * before switching the current context class loader to ExecutorClassLoader. + * + * Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the + * implementation calls "Class.forName" to check if this calls is already generated. If the + * following two objects are created in "ExecutorClassLoader.findClass", it will cause + * "ClassCircularityError". This is because loading this Netty generated class will call + * "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to use + * RPC to load it and cause to load the non-exist matcher class again. JVM will report + * `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714) + */ + private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE; + private static final MessageDecoder DECODER = MessageDecoder.INSTANCE; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this(conf, rpcHandler, false); @@ -75,8 +87,6 @@ public TransportContext( boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; - this.encoder = new MessageEncoder(); - this.decoder = new MessageDecoder(); this.closeIdleConnections = closeIdleConnections; } @@ -135,9 +145,9 @@ public TransportChannelHandler initializePipeline( try { TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() - .addLast("encoder", encoder) + .addLast("encoder", ENCODER) .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) - .addLast("decoder", decoder) + .addLast("decoder", DECODER) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index f0956438ade2..cfcedda18b74 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -35,6 +35,12 @@ public final class MessageDecoder extends MessageToMessageDecoder { private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + public static final MessageDecoder INSTANCE = new MessageDecoder(); + + private MessageDecoder() { + super(); + } + @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 276f16637efc..58a4a689781d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -35,6 +35,12 @@ public final class MessageEncoder extends MessageToMessageEncoder { private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + public static final MessageEncoder INSTANCE = new MessageEncoder(); + + private MessageEncoder() { + super(); + } + /*** * Encodes a Message by invoking its encode() method. For non-data messages, we will add one * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index c6ccae18b5e0..56782a832787 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -18,7 +18,7 @@ package org.apache.spark.network.server; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; @@ -26,7 +26,6 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ResponseMessage; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -48,7 +47,7 @@ * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not * timeout if the client is continuously sending but getting no responses, for simplicity. */ -public class TransportChannelHandler extends SimpleChannelInboundHandler { +public class TransportChannelHandler extends ChannelInboundHandlerAdapter { private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); private final TransportClient client; @@ -114,11 +113,13 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); - } else { + } else if (request instanceof ResponseMessage) { responseHandler.handle((ResponseMessage) request); + } else { + ctx.fireChannelRead(request); } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 6c8dd742f4b6..bb1c40c4b0e0 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -49,11 +49,11 @@ public class ProtocolSuite { private void testServerToClient(Message msg) { EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), - new MessageEncoder()); + MessageEncoder.INSTANCE); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new MessageDecoder()); + NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeInbound(serverChannel.readOutbound()); @@ -65,11 +65,11 @@ private void testServerToClient(Message msg) { private void testClientToServer(Message msg) { EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), - new MessageEncoder()); + MessageEncoder.INSTANCE); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new MessageDecoder()); + NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!clientChannel.outboundMessages().isEmpty()) { serverChannel.writeInbound(clientChannel.readOutbound()); diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2c1d331b9ab1..90580f426a03 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2599,14 +2599,10 @@ private[spark] object Utils extends Logging { private[util] object CallerContext extends Logging { val callerContextSupported: Boolean = { - SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && { + SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", true) && { try { - // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in - // master Maven build, so do not use it before resolving SPARK-17714. - // scalastyle:off classforname - Class.forName("org.apache.hadoop.ipc.CallerContext") - Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") - // scalastyle:on classforname + Utils.classForName("org.apache.hadoop.ipc.CallerContext") + Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") true } catch { case _: ClassNotFoundException => @@ -2681,12 +2677,8 @@ private[spark] class CallerContext( def setCurrentContext(): Unit = { if (CallerContext.callerContextSupported) { try { - // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in - // master Maven build, so do not use it before resolving SPARK-17714. - // scalastyle:off classforname - val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") - val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") - // scalastyle:on classforname + val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") + val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") val builderInst = builder.getConstructor(classOf[String]).newInstance(context) val hdfsContext = builder.getMethod("build").invoke(builderInst) callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) From 7fcb31788a24ca48c988b2ea03bb803112f484a8 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 9 Feb 2017 17:14:51 -0800 Subject: [PATCH 2/2] Restore the default value --- .../org/apache/spark/network/protocol/MessageDecoder.java | 4 +--- .../org/apache/spark/network/protocol/MessageEncoder.java | 4 +--- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index cfcedda18b74..39a7495828a8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -37,9 +37,7 @@ public final class MessageDecoder extends MessageToMessageDecoder { public static final MessageDecoder INSTANCE = new MessageDecoder(); - private MessageDecoder() { - super(); - } + private MessageDecoder() {} @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 58a4a689781d..997f74e1a21b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -37,9 +37,7 @@ public final class MessageEncoder extends MessageToMessageEncoder { public static final MessageEncoder INSTANCE = new MessageEncoder(); - private MessageEncoder() { - super(); - } + private MessageEncoder() {} /*** * Encodes a Message by invoking its encode() method. For non-data messages, we will add one diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 90580f426a03..626fbfd38273 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2599,7 +2599,7 @@ private[spark] object Utils extends Logging { private[util] object CallerContext extends Logging { val callerContextSupported: Boolean = { - SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", true) && { + SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && { try { Utils.classForName("org.apache.hadoop.ipc.CallerContext") Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")