Skip to content

Commit d4e0352

Browse files
committed
simplify code
1 parent bd32400 commit d4e0352

7 files changed

Lines changed: 74 additions & 33 deletions

File tree

common/network-common/src/main/java/org/apache/spark/network/TransportContext.java

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public TransportContext(
123123

124124
if (conf.getModuleName() != null &&
125125
conf.getModuleName().equalsIgnoreCase("shuffle") &&
126-
!isClientOnly) {
126+
!isClientOnly && conf.separateChunkFetchRequest()) {
127127
chunkFetchWorkers = NettyUtils.createEventLoop(
128128
IOMode.valueOf(conf.ioMode()),
129129
conf.chunkFetchHandlerThreads(),
@@ -187,8 +187,6 @@ public TransportChannelHandler initializePipeline(
187187
RpcHandler channelRpcHandler) {
188188
try {
189189
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
190-
ChunkFetchRequestHandler chunkFetchHandler =
191-
createChunkFetchHandler(channelHandler, channelRpcHandler);
192190
ChannelPipeline pipeline = channel.pipeline()
193191
.addLast("encoder", ENCODER)
194192
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
@@ -200,6 +198,9 @@ public TransportChannelHandler initializePipeline(
200198
.addLast("handler", channelHandler);
201199
// Use a separate EventLoopGroup to handle ChunkFetchRequest messages for shuffle rpcs.
202200
if (chunkFetchWorkers != null) {
201+
ChunkFetchRequestHandler chunkFetchHandler = new ChunkFetchRequestHandler(
202+
channelHandler.getClient(), rpcHandler.getStreamManager(),
203+
conf.maxChunksBeingTransferred(), true /* syncModeEnabled */);
203204
pipeline.addLast(chunkFetchWorkers, "chunkFetchHandler", chunkFetchHandler);
204205
}
205206
return channelHandler;
@@ -217,19 +218,19 @@ public TransportChannelHandler initializePipeline(
217218
private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
218219
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
219220
TransportClient client = new TransportClient(channel, responseHandler);
221+
boolean separateChunkFetchRequest = conf.separateChunkFetchRequest();
222+
ChunkFetchRequestHandler chunkFetchRequestHandler;
223+
if (!separateChunkFetchRequest) {
224+
chunkFetchRequestHandler = new ChunkFetchRequestHandler(
225+
client, rpcHandler.getStreamManager(),
226+
conf.maxChunksBeingTransferred(), false /* syncModeEnabled */);
227+
} else {
228+
chunkFetchRequestHandler = null;
229+
}
220230
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
221-
rpcHandler, conf.maxChunksBeingTransferred());
231+
rpcHandler, conf.maxChunksBeingTransferred(), chunkFetchRequestHandler);
222232
return new TransportChannelHandler(client, responseHandler, requestHandler,
223-
conf.connectionTimeoutMs(), closeIdleConnections, this);
224-
}
225-
226-
/**
227-
* Creates the dedicated ChannelHandler for ChunkFetchRequest messages.
228-
*/
229-
private ChunkFetchRequestHandler createChunkFetchHandler(TransportChannelHandler channelHandler,
230-
RpcHandler rpcHandler) {
231-
return new ChunkFetchRequestHandler(channelHandler.getClient(),
232-
rpcHandler.getStreamManager(), conf.maxChunksBeingTransferred());
233+
conf.connectionTimeoutMs(), separateChunkFetchRequest, closeIdleConnections, this);
233234
}
234235

235236
public TransportConf getConf() { return conf; }

common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,17 @@ public class ChunkFetchRequestHandler extends SimpleChannelInboundHandler<ChunkF
5555
private final StreamManager streamManager;
5656
/** The max number of chunks being transferred and not finished yet. */
5757
private final long maxChunksBeingTransferred;
58+
private final boolean syncModeEnabled;
5859

5960
public ChunkFetchRequestHandler(
6061
TransportClient client,
6162
StreamManager streamManager,
62-
Long maxChunksBeingTransferred) {
63+
Long maxChunksBeingTransferred,
64+
boolean syncModeEnabled) {
6365
this.client = client;
6466
this.streamManager = streamManager;
6567
this.maxChunksBeingTransferred = maxChunksBeingTransferred;
68+
this.syncModeEnabled = syncModeEnabled;
6669
}
6770

6871
@Override
@@ -76,6 +79,11 @@ protected void channelRead0(
7679
ChannelHandlerContext ctx,
7780
final ChunkFetchRequest msg) throws Exception {
7881
Channel channel = ctx.channel();
82+
processFetchRequest(channel, msg);
83+
}
84+
85+
public void processFetchRequest(
86+
final Channel channel, final ChunkFetchRequest msg) throws Exception {
7987
if (logger.isTraceEnabled()) {
8088
logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel),
8189
msg.streamChunkId);
@@ -112,19 +120,26 @@ protected void channelRead0(
112120
* channel will be handled by the EventLoop the channel is registered to. So even
113121
* though we are processing the ChunkFetchRequest in a separate thread pool, the actual I/O,
114122
* which is the potentially blocking call that could deplete server handler threads, is still
115-
* being processed by TransportServer's default EventLoopGroup. In order to throttle the max
116-
* number of threads that channel I/O for sending response to ChunkFetchRequest, the thread
117-
* calling channel.writeAndFlush will wait for the completion of sending response back to
118-
* client by invoking await(). This will throttle the rate at which threads from
119-
* ChunkFetchRequest dedicated EventLoopGroup submit channel I/O requests to TransportServer's
120-
* default EventLoopGroup, thus making sure that we can reserve some threads in
121-
* TransportServer's default EventLoopGroup for handling other RPC messages.
123+
* being processed by TransportServer's default EventLoopGroup.
124+
*
125+
* When syncModeEnabled is true, Spark will throttle the max number of threads that perform channel
126+
* I/O for sending responses to ChunkFetchRequest: the thread calling channel.writeAndFlush will wait
127+
* for the completion of sending response back to client by invoking await(). This will throttle
128+
* the rate at which threads from ChunkFetchRequest dedicated EventLoopGroup submit channel I/O
129+
* requests to TransportServer's default EventLoopGroup, thus making sure that we can reserve
130+
* some threads in TransportServer's default EventLoopGroup for handling other RPC messages.
122131
*/
123132
private ChannelFuture respond(
124133
final Channel channel,
125134
final Encodable result) throws InterruptedException {
126135
final SocketAddress remoteAddress = channel.remoteAddress();
127-
return channel.writeAndFlush(result).await().addListener((ChannelFutureListener) future -> {
136+
ChannelFuture channelFuture;
137+
if (syncModeEnabled) {
138+
channelFuture = channel.writeAndFlush(result).await();
139+
} else {
140+
channelFuture = channel.writeAndFlush(result);
141+
}
142+
return channelFuture.addListener((ChannelFutureListener) future -> {
128143
if (future.isSuccess()) {
129144
logger.trace("Sent result {} to client {}", result, remoteAddress);
130145
} else {

common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,22 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
5858
private final TransportRequestHandler requestHandler;
5959
private final long requestTimeoutNs;
6060
private final boolean closeIdleConnections;
61+
private final boolean skipChunkFetchRequest;
6162
private final TransportContext transportContext;
6263

6364
public TransportChannelHandler(
6465
TransportClient client,
6566
TransportResponseHandler responseHandler,
6667
TransportRequestHandler requestHandler,
6768
long requestTimeoutMs,
69+
boolean skipChunkFetchRequest,
6870
boolean closeIdleConnections,
6971
TransportContext transportContext) {
7072
this.client = client;
7173
this.responseHandler = responseHandler;
7274
this.requestHandler = requestHandler;
7375
this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
76+
this.skipChunkFetchRequest = skipChunkFetchRequest;
7477
this.closeIdleConnections = closeIdleConnections;
7578
this.transportContext = transportContext;
7679
}
@@ -124,7 +127,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
124127
*/
125128
@Override
126129
public boolean acceptInboundMessage(Object msg) throws Exception {
127-
if (msg instanceof ChunkFetchRequest) {
130+
if (skipChunkFetchRequest && msg instanceof ChunkFetchRequest) {
128131
return false;
129132
} else {
130133
return super.acceptInboundMessage(msg);

common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,21 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
6262
/** The max number of chunks being transferred and not finished yet. */
6363
private final long maxChunksBeingTransferred;
6464

65+
/** The dedicated ChannelHandler for ChunkFetchRequest messages. */
66+
private final ChunkFetchRequestHandler chunkFetchRequestHandler;
67+
6568
public TransportRequestHandler(
6669
Channel channel,
6770
TransportClient reverseClient,
6871
RpcHandler rpcHandler,
69-
Long maxChunksBeingTransferred) {
72+
Long maxChunksBeingTransferred,
73+
ChunkFetchRequestHandler chunkFetchRequestHandler) {
7074
this.channel = channel;
7175
this.reverseClient = reverseClient;
7276
this.rpcHandler = rpcHandler;
7377
this.streamManager = rpcHandler.getStreamManager();
7478
this.maxChunksBeingTransferred = maxChunksBeingTransferred;
79+
this.chunkFetchRequestHandler = chunkFetchRequestHandler;
7580
}
7681

7782
@Override
@@ -97,8 +102,10 @@ public void channelInactive() {
97102
}
98103

99104
@Override
100-
public void handle(RequestMessage request) {
101-
if (request instanceof RpcRequest) {
105+
public void handle(RequestMessage request) throws Exception {
106+
if (request instanceof ChunkFetchRequest) {
107+
chunkFetchRequestHandler.processFetchRequest(channel, (ChunkFetchRequest) request);
108+
} else if (request instanceof RpcRequest) {
102109
processRpcRequest((RpcRequest) request);
103110
} else if (request instanceof OneWayMessage) {
104111
processOneWayMessage((OneWayMessage) request);

common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.network.util;
1919

2020
import java.util.Locale;
21+
import java.util.NoSuchElementException;
2122
import java.util.Properties;
2223

2324
import com.google.common.primitives.Ints;
@@ -316,7 +317,8 @@ public long maxChunksBeingTransferred() {
316317

317318
/**
318319
* Percentage of io.serverThreads used by netty to process ChunkFetchRequest.
319-
* Shuffle server will use a separate EventLoopGroup to process ChunkFetchRequest messages.
320+
* When the config `spark.shuffle.server.chunkFetchHandlerThreadsPercent` is set,
321+
* shuffle server will use a separate EventLoopGroup to process ChunkFetchRequest messages.
320322
* Although when calling the async writeAndFlush on the underlying channel to send
321323
* response back to client, the I/O on the channel is still being handled by
322324
* {@link org.apache.spark.network.server.TransportServer}'s default EventLoopGroup
@@ -339,12 +341,25 @@ public int chunkFetchHandlerThreads() {
339341
return 0;
340342
}
341343
int chunkFetchHandlerThreadsPercent =
342-
conf.getInt("spark.shuffle.server.chunkFetchHandlerThreadsPercent", 100);
344+
Integer.parseInt(conf.get("spark.shuffle.server.chunkFetchHandlerThreadsPercent"));
343345
int threads =
344346
this.serverThreads() > 0 ? this.serverThreads() : 2 * NettyRuntime.availableProcessors();
345347
return (int) Math.ceil(threads * (chunkFetchHandlerThreadsPercent / 100.0));
346348
}
347349

350+
/**
351+
* Whether to use a separate EventLoopGroup to process ChunkFetchRequest messages, which is decided
352+
* by whether the config `spark.shuffle.server.chunkFetchHandlerThreadsPercent` is set or not.
353+
*/
354+
public boolean separateChunkFetchRequest() {
355+
try {
356+
conf.get("spark.shuffle.server.chunkFetchHandlerThreadsPercent");
357+
return true;
358+
} catch (NoSuchElementException e) {
359+
return false;
360+
}
361+
}
362+
348363
/**
349364
* Whether to use the old protocol while doing the shuffle block fetching.
350365
* It is only enabled while we need the compatibility in the scenario of new spark version

common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.List;
2323

2424
import io.netty.channel.Channel;
25-
import org.apache.spark.network.server.ChunkFetchRequestHandler;
2625
import org.junit.Assert;
2726
import org.junit.Test;
2827

@@ -33,6 +32,7 @@
3332
import org.apache.spark.network.buffer.ManagedBuffer;
3433
import org.apache.spark.network.client.TransportClient;
3534
import org.apache.spark.network.protocol.*;
35+
import org.apache.spark.network.server.ChunkFetchRequestHandler;
3636
import org.apache.spark.network.server.NoOpRpcHandler;
3737
import org.apache.spark.network.server.OneForOneStreamManager;
3838
import org.apache.spark.network.server.RpcHandler;
@@ -68,7 +68,7 @@ public void handleChunkFetchRequest() throws Exception {
6868
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
6969
TransportClient reverseClient = mock(TransportClient.class);
7070
ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient,
71-
rpcHandler.getStreamManager(), 2L);
71+
rpcHandler.getStreamManager(), 2L, false);
7272

7373
RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0));
7474
requestHandler.channelRead(context, request0);

common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
public class TransportRequestHandlerSuite {
4040

4141
@Test
42-
public void handleStreamRequest() {
42+
public void handleStreamRequest() throws Exception {
4343
RpcHandler rpcHandler = new NoOpRpcHandler();
4444
OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager());
4545
Channel channel = mock(Channel.class);
@@ -66,7 +66,7 @@ public void handleStreamRequest() {
6666

6767
TransportClient reverseClient = mock(TransportClient.class);
6868
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
69-
rpcHandler, 2L);
69+
rpcHandler, 2L, null);
7070

7171
RequestMessage request0 = new StreamRequest(String.format("%d_%d", streamId, 0));
7272
requestHandler.handle(request0);

0 commit comments

Comments
 (0)