Skip to content

Commit 9bf2120

Browse files
Marcelo Vanzin authored and Andrew Or committed
[SPARK-12007][NETWORK] Avoid copies in the network lib's RPC layer.
This change seems large, but most of it is just replacing `byte[]` with `ByteBuffer` and `new byte[]` with `ByteBuffer.allocate()`, since it changes the network library's API.

The following are the parts of the code that actually have meaningful changes:

- The Message implementations were changed to inherit from a new AbstractMessage that can optionally hold a reference to a body (in the form of a ManagedBuffer); this is similar to how ResponseWithBody worked before, except now it's not restricted to just responses.

- The TransportFrameDecoder was pretty much rewritten to avoid copies as much as possible; it doesn't rely on CompositeByteBuf to accumulate incoming data anymore, since CompositeByteBuf has issues when slices are retained. The code now is able to create frames without having to resort to copying bytes, except for a few bytes (containing the frame length) in very rare cases.

- Some minor changes in the SASL layer to convert things back to `byte[]`, since the JDK SASL API operates on those.

Author: Marcelo Vanzin <[email protected]>

Closes #9987 from vanzin/SPARK-12007.
1 parent 0a46e43 commit 9bf2120

50 files changed

Lines changed: 589 additions & 307 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.deploy.mesos
1919

2020
import java.net.SocketAddress
21+
import java.nio.ByteBuffer
2122

2223
import scala.collection.mutable
2324

@@ -56,7 +57,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo
5657
}
5758
}
5859
connectedApps(address) = appId
59-
callback.onSuccess(new Array[Byte](0))
60+
callback.onSuccess(ByteBuffer.allocate(0))
6061
case _ => super.handleMessage(message, client, callback)
6162
}
6263
}

core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,9 +47,9 @@ class NettyBlockRpcServer(
4747

4848
override def receive(
4949
client: TransportClient,
50-
messageBytes: Array[Byte],
50+
rpcMessage: ByteBuffer,
5151
responseContext: RpcResponseCallback): Unit = {
52-
val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
52+
val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
5353
logTrace(s"Received request: $message")
5454

5555
message match {
@@ -58,15 +58,15 @@ class NettyBlockRpcServer(
5858
openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
5959
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
6060
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
61-
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
61+
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)
6262

6363
case uploadBlock: UploadBlock =>
6464
// StorageLevel is serialized as bytes using our JavaSerializer.
6565
val level: StorageLevel =
6666
serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
6767
val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
6868
blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
69-
responseContext.onSuccess(new Array[Byte](0))
69+
responseContext.onSuccess(ByteBuffer.allocate(0))
7070
}
7171
}
7272

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.network.netty
1919

20+
import java.nio.ByteBuffer
21+
2022
import scala.collection.JavaConverters._
2123
import scala.concurrent.{Future, Promise}
2224

@@ -133,9 +135,9 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
133135
data
134136
}
135137

136-
client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
138+
client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
137139
new RpcResponseCallback {
138-
override def onSuccess(response: Array[Byte]): Unit = {
140+
override def onSuccess(response: ByteBuffer): Unit = {
139141
logTrace(s"Successfully uploaded block $blockId")
140142
result.success((): Unit)
141143
}

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -241,16 +241,14 @@ private[netty] class NettyRpcEnv(
241241
promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
242242
}
243243

244-
private[netty] def serialize(content: Any): Array[Byte] = {
245-
val buffer = javaSerializerInstance.serialize(content)
246-
java.util.Arrays.copyOfRange(
247-
buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
244+
private[netty] def serialize(content: Any): ByteBuffer = {
245+
javaSerializerInstance.serialize(content)
248246
}
249247

250-
private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = {
248+
private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
251249
NettyRpcEnv.currentClient.withValue(client) {
252250
deserialize { () =>
253-
javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
251+
javaSerializerInstance.deserialize[T](bytes)
254252
}
255253
}
256254
}
@@ -557,20 +555,20 @@ private[netty] class NettyRpcHandler(
557555

558556
override def receive(
559557
client: TransportClient,
560-
message: Array[Byte],
558+
message: ByteBuffer,
561559
callback: RpcResponseCallback): Unit = {
562560
val messageToDispatch = internalReceive(client, message)
563561
dispatcher.postRemoteMessage(messageToDispatch, callback)
564562
}
565563

566564
override def receive(
567565
client: TransportClient,
568-
message: Array[Byte]): Unit = {
566+
message: ByteBuffer): Unit = {
569567
val messageToDispatch = internalReceive(client, message)
570568
dispatcher.postOneWayMessage(messageToDispatch)
571569
}
572570

573-
private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = {
571+
private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = {
574572
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
575573
assert(addr != null)
576574
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)

core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.rpc.netty
1919

20+
import java.nio.ByteBuffer
2021
import java.util.concurrent.Callable
2122
import javax.annotation.concurrent.GuardedBy
2223

@@ -34,7 +35,7 @@ private[netty] sealed trait OutboxMessage {
3435

3536
}
3637

37-
private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage
38+
private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage
3839
with Logging {
3940

4041
override def sendWith(client: TransportClient): Unit = {
@@ -48,9 +49,9 @@ private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends Outb
4849
}
4950

5051
private[netty] case class RpcOutboxMessage(
51-
content: Array[Byte],
52+
content: ByteBuffer,
5253
_onFailure: (Throwable) => Unit,
53-
_onSuccess: (TransportClient, Array[Byte]) => Unit)
54+
_onSuccess: (TransportClient, ByteBuffer) => Unit)
5455
extends OutboxMessage with RpcResponseCallback {
5556

5657
private var client: TransportClient = _
@@ -70,7 +71,7 @@ private[netty] case class RpcOutboxMessage(
7071
_onFailure(e)
7172
}
7273

73-
override def onSuccess(response: Array[Byte]): Unit = {
74+
override def onSuccess(response: ByteBuffer): Unit = {
7475
_onSuccess(client, response)
7576
}
7677

core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.rpc.netty
1919

2020
import java.net.InetSocketAddress
21+
import java.nio.ByteBuffer
2122

2223
import io.netty.channel.Channel
2324
import org.mockito.Mockito._
@@ -32,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
3233

3334
val env = mock(classOf[NettyRpcEnv])
3435
val sm = mock(classOf[StreamManager])
35-
when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
36+
when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any()))
3637
.thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))
3738

3839
test("receive") {

network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,13 +17,15 @@
1717

1818
package org.apache.spark.network.client;
1919

20+
import java.nio.ByteBuffer;
21+
2022
/**
2123
* Callback for the result of a single RPC. This will be invoked once with either success or
2224
* failure.
2325
*/
2426
public interface RpcResponseCallback {
2527
/** Successful serialized result from server. */
26-
void onSuccess(byte[] response);
28+
void onSuccess(ByteBuffer response);
2729

2830
/** Exception either propagated from server or raised on client side. */
2931
void onFailure(Throwable e);

network/common/src/main/java/org/apache/spark/network/client/TransportClient.java

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import java.io.Closeable;
2121
import java.io.IOException;
2222
import java.net.SocketAddress;
23+
import java.nio.ByteBuffer;
2324
import java.util.UUID;
2425
import java.util.concurrent.ExecutionException;
2526
import java.util.concurrent.TimeUnit;
@@ -36,6 +37,7 @@
3637
import org.slf4j.Logger;
3738
import org.slf4j.LoggerFactory;
3839

40+
import org.apache.spark.network.buffer.NioManagedBuffer;
3941
import org.apache.spark.network.protocol.ChunkFetchRequest;
4042
import org.apache.spark.network.protocol.OneWayMessage;
4143
import org.apache.spark.network.protocol.RpcRequest;
@@ -212,15 +214,15 @@ public void operationComplete(ChannelFuture future) throws Exception {
212214
* @param callback Callback to handle the RPC's reply.
213215
* @return The RPC's id.
214216
*/
215-
public long sendRpc(byte[] message, final RpcResponseCallback callback) {
217+
public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) {
216218
final String serverAddr = NettyUtils.getRemoteAddress(channel);
217219
final long startTime = System.currentTimeMillis();
218220
logger.trace("Sending RPC to {}", serverAddr);
219221

220222
final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
221223
handler.addRpcRequest(requestId, callback);
222224

223-
channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
225+
channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener(
224226
new ChannelFutureListener() {
225227
@Override
226228
public void operationComplete(ChannelFuture future) throws Exception {
@@ -249,12 +251,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
249251
* Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
250252
* a specified timeout for a response.
251253
*/
252-
public byte[] sendRpcSync(byte[] message, long timeoutMs) {
253-
final SettableFuture<byte[]> result = SettableFuture.create();
254+
public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
255+
final SettableFuture<ByteBuffer> result = SettableFuture.create();
254256

255257
sendRpc(message, new RpcResponseCallback() {
256258
@Override
257-
public void onSuccess(byte[] response) {
259+
public void onSuccess(ByteBuffer response) {
258260
result.set(response);
259261
}
260262

@@ -279,8 +281,8 @@ public void onFailure(Throwable e) {
279281
*
280282
* @param message The message to send.
281283
*/
282-
public void send(byte[] message) {
283-
channel.writeAndFlush(new OneWayMessage(message));
284+
public void send(ByteBuffer message) {
285+
channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
284286
}
285287

286288
/**

network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -136,19 +136,19 @@ public void exceptionCaught(Throwable cause) {
136136
}
137137

138138
@Override
139-
public void handle(ResponseMessage message) {
139+
public void handle(ResponseMessage message) throws Exception {
140140
String remoteAddress = NettyUtils.getRemoteAddress(channel);
141141
if (message instanceof ChunkFetchSuccess) {
142142
ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
143143
ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
144144
if (listener == null) {
145145
logger.warn("Ignoring response for block {} from {} since it is not outstanding",
146146
resp.streamChunkId, remoteAddress);
147-
resp.body.release();
147+
resp.body().release();
148148
} else {
149149
outstandingFetches.remove(resp.streamChunkId);
150-
listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body);
151-
resp.body.release();
150+
listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
151+
resp.body().release();
152152
}
153153
} else if (message instanceof ChunkFetchFailure) {
154154
ChunkFetchFailure resp = (ChunkFetchFailure) message;
@@ -166,10 +166,14 @@ public void handle(ResponseMessage message) {
166166
RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
167167
if (listener == null) {
168168
logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
169-
resp.requestId, remoteAddress, resp.response.length);
169+
resp.requestId, remoteAddress, resp.body().size());
170170
} else {
171171
outstandingRpcs.remove(resp.requestId);
172-
listener.onSuccess(resp.response);
172+
try {
173+
listener.onSuccess(resp.body().nioByteBuffer());
174+
} finally {
175+
resp.body().release();
176+
}
173177
}
174178
} else if (message instanceof RpcFailure) {
175179
RpcFailure resp = (RpcFailure) message;
Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.protocol;
19+
20+
import com.google.common.base.Objects;
21+
22+
import org.apache.spark.network.buffer.ManagedBuffer;
23+
24+
/**
25+
* Abstract class for messages which optionally contain a body kept in a separate buffer.
26+
*/
27+
public abstract class AbstractMessage implements Message {
28+
private final ManagedBuffer body;
29+
private final boolean isBodyInFrame;
30+
31+
protected AbstractMessage() {
32+
this(null, false);
33+
}
34+
35+
protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) {
36+
this.body = body;
37+
this.isBodyInFrame = isBodyInFrame;
38+
}
39+
40+
@Override
41+
public ManagedBuffer body() {
42+
return body;
43+
}
44+
45+
@Override
46+
public boolean isBodyInFrame() {
47+
return isBodyInFrame;
48+
}
49+
50+
protected boolean equals(AbstractMessage other) {
51+
return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body);
52+
}
53+
54+
}

0 commit comments

Comments (0)