Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ed2f7d4
Initial version
attilapiros Apr 18, 2019
c6d81ec
applying review comments 1.0
attilapiros Apr 30, 2019
528b05a
java checkstyle fix
attilapiros Apr 30, 2019
98d2cbc
java checkstyle fix 2.0
attilapiros Apr 30, 2019
d5a3149
java checkstyle fix 3.0
attilapiros Apr 30, 2019
00d456a
SyncBlockTransferClient
attilapiros May 1, 2019
37ca716
Handling of deleted files
attilapiros May 1, 2019
fd0b107
mostly indentation fixes
attilapiros May 3, 2019
e7e0539
cleanup of non shuffle service served files
attilapiros May 3, 2019
7f2fc12
fix line length
attilapiros May 3, 2019
6b08379
applying review comments
attilapiros May 6, 2019
767f5f7
Applying review comments (from Imran)
attilapiros May 7, 2019
2aa2097
adding back MEMORY_AND_DISK
attilapiros May 7, 2019
928be2c
applying review comments of Imran v2.0
attilapiros May 8, 2019
79ed69a
applying review comments of Imran v3.0
attilapiros May 13, 2019
f8beff1
extend External Shuffle Service with remove blocks
attilapiros May 16, 2019
d42eeeb
fix
attilapiros May 16, 2019
d6ca9c8
fix2
attilapiros May 16, 2019
39c9914
introduce spark.shuffle.service.fetch.rdd.enabled
attilapiros May 17, 2019
16ab64f
fix checkstyle
attilapiros May 17, 2019
43fac8b
keep the old logic: sending RemoveRdd to every live executor
attilapiros May 18, 2019
5a1a15a
javastyle fix
attilapiros May 18, 2019
e3adc05
applying review comments of Vanzin
attilapiros May 21, 2019
bf9ec92
applying review comments
attilapiros May 22, 2019
faa583f
fixing: test NITs
attilapiros May 23, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ protected void channelRead0(
try {
streamManager.checkAuthorization(client, msg.streamChunkId.streamId);
buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex);
if (buf == null) {
throw new IllegalStateException("Chunk was not found");
}
} catch (Exception e) {
logger.error(String.format("Error opening block %s for request from %s",
msg.streamChunkId, getRemoteAddress(channel)), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ public void connectionTerminated(Channel channel) {

// Release all remaining buffers.
while (state.buffers.hasNext()) {
state.buffers.next().release();
ManagedBuffer buffer = state.buffers.next();
if (buffer != null) {
buffer.release();
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import io.netty.channel.Channel;
import org.apache.spark.network.server.ChunkFetchRequestHandler;
import org.junit.Assert;
import org.junit.Test;

import static org.mockito.Mockito.*;
Expand All @@ -45,9 +46,8 @@ public void handleChunkFetchRequest() throws Exception {
Channel channel = mock(Channel.class);
ChannelHandlerContext context = mock(ChannelHandlerContext.class);
when(context.channel())
.thenAnswer(invocationOnMock0 -> {
return channel;
});
.thenAnswer(invocationOnMock0 -> channel);

List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs =
new ArrayList<>();
when(channel.writeAndFlush(any()))
Expand All @@ -62,6 +62,7 @@ public void handleChunkFetchRequest() throws Exception {
List<ManagedBuffer> managedBuffers = new ArrayList<>();
managedBuffers.add(new TestManagedBuffer(10));
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(null);
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
Expand All @@ -71,31 +72,40 @@ public void handleChunkFetchRequest() throws Exception {

RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0));
requestHandler.channelRead(context, request0);
assert responseAndPromisePairs.size() == 1;
assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess;
assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() ==
managedBuffers.get(0);
Assert.assertEquals(1, responseAndPromisePairs.size());
Assert.assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess);
Assert.assertEquals(managedBuffers.get(0),
((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body());

RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1));
requestHandler.channelRead(context, request1);
assert responseAndPromisePairs.size() == 2;
assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess;
assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() ==
managedBuffers.get(1);
Assert.assertEquals(2, responseAndPromisePairs.size());
Assert.assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess);
Assert.assertEquals(managedBuffers.get(1),
((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body());

// Finish flushing the response for request0.
responseAndPromisePairs.get(0).getRight().finish(true);

RequestMessage request2 = new ChunkFetchRequest(new StreamChunkId(streamId, 2));
requestHandler.channelRead(context, request2);
assert responseAndPromisePairs.size() == 3;
assert responseAndPromisePairs.get(2).getLeft() instanceof ChunkFetchSuccess;
assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(2).getLeft())).body() ==
managedBuffers.get(2);
Assert.assertEquals(3, responseAndPromisePairs.size());
Assert.assertTrue(responseAndPromisePairs.get(2).getLeft() instanceof ChunkFetchFailure);
ChunkFetchFailure chunkFetchFailure =
((ChunkFetchFailure) (responseAndPromisePairs.get(2).getLeft()));
Assert.assertEquals("java.lang.IllegalStateException: Chunk was not found",
chunkFetchFailure.errorString.split("\\r?\\n")[0]);

RequestMessage request3 = new ChunkFetchRequest(new StreamChunkId(streamId, 3));
requestHandler.channelRead(context, request3);
Assert.assertEquals(4, responseAndPromisePairs.size());
Assert.assertTrue(responseAndPromisePairs.get(3).getLeft() instanceof ChunkFetchSuccess);
Assert.assertEquals(managedBuffers.get(3),
((ChunkFetchSuccess) (responseAndPromisePairs.get(3).getLeft())).body());

RequestMessage request4 = new ChunkFetchRequest(new StreamChunkId(streamId, 4));
requestHandler.channelRead(context, request4);
verify(channel, times(1)).close();
assert responseAndPromisePairs.size() == 3;
Assert.assertEquals(4, responseAndPromisePairs.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.List;

import io.netty.channel.Channel;
import org.junit.Assert;
import org.junit.Test;

import static org.mockito.Mockito.*;
Expand All @@ -38,7 +39,7 @@
public class TransportRequestHandlerSuite {

@Test
public void handleStreamRequest() throws Exception {
public void handleStreamRequest() {
RpcHandler rpcHandler = new NoOpRpcHandler();
OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager());
Channel channel = mock(Channel.class);
Expand All @@ -56,48 +57,56 @@ public void handleStreamRequest() throws Exception {
List<ManagedBuffer> managedBuffers = new ArrayList<>();
managedBuffers.add(new TestManagedBuffer(10));
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(null);
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);

assert streamManager.numStreamStates() == 1;
Assert.assertEquals(1, streamManager.numStreamStates());

TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L);

RequestMessage request0 = new StreamRequest(String.format("%d_%d", streamId, 0));
requestHandler.handle(request0);
assert responseAndPromisePairs.size() == 1;
assert responseAndPromisePairs.get(0).getLeft() instanceof StreamResponse;
assert ((StreamResponse) (responseAndPromisePairs.get(0).getLeft())).body() ==
managedBuffers.get(0);
Assert.assertEquals(1, responseAndPromisePairs.size());
Assert.assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof StreamResponse);
Assert.assertEquals(managedBuffers.get(0),
((StreamResponse) (responseAndPromisePairs.get(0).getLeft())).body());

RequestMessage request1 = new StreamRequest(String.format("%d_%d", streamId, 1));
requestHandler.handle(request1);
assert responseAndPromisePairs.size() == 2;
assert responseAndPromisePairs.get(1).getLeft() instanceof StreamResponse;
assert ((StreamResponse) (responseAndPromisePairs.get(1).getLeft())).body() ==
managedBuffers.get(1);
Assert.assertEquals(2, responseAndPromisePairs.size());
Assert.assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof StreamResponse);
Assert.assertEquals(managedBuffers.get(1),
((StreamResponse) (responseAndPromisePairs.get(1).getLeft())).body());

// Finish flushing the response for request0.
responseAndPromisePairs.get(0).getRight().finish(true);

RequestMessage request2 = new StreamRequest(String.format("%d_%d", streamId, 2));
StreamRequest request2 = new StreamRequest(String.format("%d_%d", streamId, 2));
requestHandler.handle(request2);
assert responseAndPromisePairs.size() == 3;
assert responseAndPromisePairs.get(2).getLeft() instanceof StreamResponse;
assert ((StreamResponse) (responseAndPromisePairs.get(2).getLeft())).body() ==
managedBuffers.get(2);
Assert.assertEquals(3, responseAndPromisePairs.size());
Assert.assertTrue(responseAndPromisePairs.get(2).getLeft() instanceof StreamFailure);
Assert.assertEquals(String.format("Stream '%s' was not found.", request2.streamId),
((StreamFailure) (responseAndPromisePairs.get(2).getLeft())).error);

// Request3 will trigger the close of channel, because the number of max chunks being
// transferred is 2;
RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3));
requestHandler.handle(request3);
Assert.assertEquals(4, responseAndPromisePairs.size());
Assert.assertTrue(responseAndPromisePairs.get(3).getLeft() instanceof StreamResponse);
Assert.assertEquals(managedBuffers.get(3),
((StreamResponse) (responseAndPromisePairs.get(3).getLeft())).body());

// Request4 will trigger the close of channel, because the number of max chunks being
// transferred is 2;
RequestMessage request4 = new StreamRequest(String.format("%d_%d", streamId, 4));
requestHandler.handle(request4);
verify(channel, times(1)).close();
assert responseAndPromisePairs.size() == 3;
Assert.assertEquals(4, responseAndPromisePairs.size());

streamManager.connectionTerminated(channel);
assert streamManager.numStreamStates() == 0;
Assert.assertEquals(0, streamManager.numStreamStates());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.List;

import io.netty.channel.Channel;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

Expand All @@ -29,23 +31,69 @@

public class OneForOneStreamManagerSuite {

List<ManagedBuffer> managedBuffersToRelease = new ArrayList<>();

/**
 * Releases every buffer recorded in {@code managedBuffersToRelease} and then empties
 * the list so the next test starts without leftovers.
 */
@After
public void tearDown() {
  for (ManagedBuffer pending : managedBuffersToRelease) {
    pending.release();
  }
  managedBuffersToRelease.clear();
}

/**
 * Fetches a chunk from the given stream manager and, when the chunk is non-null,
 * records it in {@code managedBuffersToRelease} so that it is released in
 * {@code tearDown()}.
 *
 * @param manager stream manager to query
 * @param streamId id of the registered stream
 * @param chunkIndex index of the chunk within the stream
 * @return the chunk returned by the manager, possibly {@code null}
 */
private ManagedBuffer getChunk(OneForOneStreamManager manager, long streamId, int chunkIndex) {
  ManagedBuffer fetched = manager.getChunk(streamId, chunkIndex);
  if (fetched == null) {
    // Nothing to track for a missing chunk.
    return null;
  }
  managedBuffersToRelease.add(fetched);
  return fetched;
}

@Test
public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
public void testMissingChunk() {
OneForOneStreamManager manager = new OneForOneStreamManager();
List<ManagedBuffer> buffers = new ArrayList<>();
TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
TestManagedBuffer buffer3 = Mockito.spy(new TestManagedBuffer(20));

buffers.add(buffer1);
// the nulls here are to simulate a file which goes missing before being read,
// just as a defensive measure
buffers.add(null);
buffers.add(buffer2);
buffers.add(null);
buffers.add(buffer3);

Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
manager.registerStream("appId", buffers.iterator(), dummyChannel);
assert manager.numStreamStates() == 1;
long streamId = manager.registerStream("appId", buffers.iterator(), dummyChannel);
Assert.assertEquals(1, manager.numStreamStates());
Assert.assertNotNull(getChunk(manager, streamId, 0));
Assert.assertNull(getChunk(manager, streamId, 1));
Assert.assertNotNull(getChunk(manager, streamId, 2));
manager.connectionTerminated(dummyChannel);

// loaded buffers are not released yet, as in production a ManagedBuffer returned by
// getChunk() would only be released by Netty after it is written to the network
Mockito.verify(buffer1, Mockito.never()).release();
Mockito.verify(buffer2, Mockito.never()).release();
Mockito.verify(buffer3, Mockito.times(1)).release();
}

@Test
public void managedBuffersAreFreedWhenConnectionIsClosed() {
OneForOneStreamManager manager = new OneForOneStreamManager();
List<ManagedBuffer> buffers = new ArrayList<>();
TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
buffers.add(buffer1);
buffers.add(buffer2);

Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
manager.registerStream("appId", buffers.iterator(), dummyChannel);
Assert.assertEquals(1, manager.numStreamStates());
manager.connectionTerminated(dummyChannel);

Mockito.verify(buffer1, Mockito.times(1)).release();
Mockito.verify(buffer2, Mockito.times(1)).release();
assert manager.numStreamStates() == 0;
Assert.assertEquals(0, manager.numStreamStates());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle;

/**
 * Holder for configuration-key constants shared within the
 * {@code org.apache.spark.network.shuffle} package.
 *
 * <p>This class only exposes static constants, so it is final and cannot be instantiated.
 */
public final class Constants {

  /**
   * Name of the {@code spark.shuffle.service.fetch.rdd.enabled} configuration key.
   * NOTE(review): presumably toggles serving RDD blocks through the external shuffle
   * service; confirm against the Spark configuration documentation.
   */
  public static final String SHUFFLE_SERVICE_FETCH_RDD_ENABLED =
      "spark.shuffle.service.fetch.rdd.enabled";

  /** Not instantiable: this type is a namespace for constants only. */
  private Constants() {
  }
}
Loading