Skip to content

Commit 57fc4d7

Browse files
committed
Added test cases for BlockHeaderEncoder and BlockFetchingClientHandlerSuite.
1 parent 22623e9 commit 57fc4d7

6 files changed

Lines changed: 166 additions & 6 deletions

File tree

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ import org.apache.spark.Logging
3333
* Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]].
3434
* Use [[BlockFetchingClientFactory]] to instantiate this client.
3535
*
36+
* The constructor blocks until a connection is successfully established.
37+
*
3638
* See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
3739
*
3840
* Concurrency: [[BlockFetchingClient]] is not thread safe and should not be shared.

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ class BlockFetchingClientFactory(val conf: NettyConfig) {
8383
/**
8484
* Create a new BlockFetchingClient connecting to the given remote host / port.
8585
*
86+
* This blocks until a connection is successfully established.
87+
*
8688
* Concurrency: This method is safe to call from multiple threads.
8789
*/
8890
def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = {

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,19 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] wi
4444
val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
4545
in.readBytes(blockIdBytes)
4646
val blockId = new String(blockIdBytes)
47-
val blockLen = totalLen - math.abs(blockIdLen) - 4
47+
val blockSize = totalLen - math.abs(blockIdLen) - 4
4848

4949
def server = ctx.channel.remoteAddress.toString
5050

5151
// blockIdLen is negative when it is an error message.
5252
if (blockIdLen < 0) {
53-
val errorMessageBytes = new Array[Byte](blockLen)
53+
val errorMessageBytes = new Array[Byte](blockSize)
5454
in.readBytes(errorMessageBytes)
5555
val errorMsg = new String(errorMessageBytes)
56-
logTrace(s"Received block $blockId ($blockLen B) with error $errorMsg from $server")
56+
logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
5757
blockFetchFailureCallback(blockId, errorMsg)
5858
} else {
59-
logTrace(s"Received block $blockId ($blockLen B) from $server")
59+
logTrace(s"Received block $blockId ($blockSize B) from $server")
6060
blockFetchSuccessCallback(blockId, new ReferenceCountedBuffer(in))
6161
}
6262
}

core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.network.netty
2020
import java.io.{RandomAccessFile, File}
2121
import java.nio.ByteBuffer
2222
import java.util.{Collections, HashSet}
23-
import java.util.concurrent.Semaphore
23+
import java.util.concurrent.{TimeUnit, Semaphore}
2424

2525
import scala.collection.JavaConversions._
2626

@@ -34,6 +34,9 @@ import org.apache.spark.network.netty.server.BlockServer
3434
import org.apache.spark.storage.{FileSegment, BlockDataProvider}
3535

3636

37+
/**
38+
* Test suite that makes sure the server and the client implementations share the same protocol.
39+
*/
3740
class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
3841

3942
val bufSize = 100000
@@ -108,7 +111,9 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
108111
sem.release()
109112
}
110113
)
111-
sem.acquire(blockIds.size)
114+
if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {
115+
fail("Timeout getting response from the server")
116+
}
112117
client.close()
113118
(receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
114119
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.netty.client
19+
20+
import java.nio.ByteBuffer
21+
22+
import io.netty.buffer.Unpooled
23+
import io.netty.channel.embedded.EmbeddedChannel
24+
25+
import org.scalatest.FunSuite
26+
27+
28+
class BlockFetchingClientHandlerSuite extends FunSuite {
29+
30+
test("handling block data (successful fetch)") {
31+
val blockId = "test_block"
32+
val blockData = "blahblahblahblahblah"
33+
val totalLength = 4 + blockId.length + blockData.length
34+
35+
var parsedBlockId: String = ""
36+
var parsedBlockData: String = ""
37+
val handler = new BlockFetchingClientHandler
38+
handler.blockFetchSuccessCallback = (bid, refCntBuf) => {
39+
parsedBlockId = bid
40+
val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
41+
refCntBuf.byteBuffer().get(bytes)
42+
parsedBlockData = new String(bytes)
43+
}
44+
45+
val channel = new EmbeddedChannel(handler)
46+
val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
47+
buf.putInt(totalLength)
48+
buf.putInt(blockId.length)
49+
buf.put(blockId.getBytes)
50+
buf.put(blockData.getBytes)
51+
buf.flip()
52+
53+
channel.writeInbound(Unpooled.wrappedBuffer(buf))
54+
assert(parsedBlockId === blockId)
55+
assert(parsedBlockData === blockData)
56+
57+
channel.close()
58+
}
59+
60+
test("handling error message (failed fetch)") {
61+
val blockId = "test_block"
62+
val errorMsg = "error erro5r error err4or error3 error6 error erro1r"
63+
val totalLength = 4 + blockId.length + errorMsg.length
64+
65+
var parsedBlockId: String = ""
66+
var parsedErrorMsg: String = ""
67+
val handler = new BlockFetchingClientHandler
68+
handler.blockFetchFailureCallback = (bid, msg) => {
69+
parsedBlockId = bid
70+
parsedErrorMsg = msg
71+
}
72+
73+
val channel = new EmbeddedChannel(handler)
74+
val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
75+
buf.putInt(totalLength)
76+
buf.putInt(-blockId.length)
77+
buf.put(blockId.getBytes)
78+
buf.put(errorMsg.getBytes)
79+
buf.flip()
80+
81+
channel.writeInbound(Unpooled.wrappedBuffer(buf))
82+
assert(parsedBlockId === blockId)
83+
assert(parsedErrorMsg === errorMsg)
84+
85+
channel.close()
86+
}
87+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.netty.server
19+
20+
import io.netty.buffer.ByteBuf
21+
import io.netty.channel.embedded.EmbeddedChannel
22+
23+
import org.scalatest.FunSuite
24+
25+
26+
class BlockHeaderEncoderSuite extends FunSuite {
27+
28+
test("encode normal block data") {
29+
val blockId = "test_block"
30+
val channel = new EmbeddedChannel(new BlockHeaderEncoder)
31+
channel.writeOutbound(new BlockHeader(17, blockId, None))
32+
val out = channel.readOutbound().asInstanceOf[ByteBuf]
33+
assert(out.readInt() === 4 + blockId.length + 17)
34+
assert(out.readInt() === blockId.length)
35+
36+
val blockIdBytes = new Array[Byte](blockId.length)
37+
out.readBytes(blockIdBytes)
38+
assert(new String(blockIdBytes) === blockId)
39+
assert(out.readableBytes() === 0)
40+
41+
channel.close()
42+
}
43+
44+
test("encode error message") {
45+
val blockId = "error_block"
46+
val errorMsg = "error encountered"
47+
val channel = new EmbeddedChannel(new BlockHeaderEncoder)
48+
channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg)))
49+
val out = channel.readOutbound().asInstanceOf[ByteBuf]
50+
assert(out.readInt() === 4 + blockId.length + errorMsg.length)
51+
assert(out.readInt() === -blockId.length)
52+
53+
val blockIdBytes = new Array[Byte](blockId.length)
54+
out.readBytes(blockIdBytes)
55+
assert(new String(blockIdBytes) === blockId)
56+
57+
val errorMsgBytes = new Array[Byte](errorMsg.length)
58+
out.readBytes(errorMsgBytes)
59+
assert(new String(errorMsgBytes) === errorMsg)
60+
assert(out.readableBytes() === 0)
61+
62+
channel.close()
63+
}
64+
}

0 commit comments

Comments
 (0)