diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e0276a4dc4224..e2ddf0c7a6914 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -129,6 +129,8 @@ private[spark] class BlockManager(
 
   private[spark] val externalShuffleServiceEnabled =
     conf.getBoolean("spark.shuffle.service.enabled", false)
+  private val chunkSize =
+    conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString).toInt
 
   val diskBlockManager = {
     // Only perform cleanup if an external service is not serving our shuffle files.
@@ -659,6 +661,11 @@ private[spark] class BlockManager(
    * Get block from remote block managers as serialized bytes.
    */
   def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
+    // TODO if we change this method to return the ManagedBuffer, then getRemoteValues
+    // could just use the inputStream on the temp file, rather than memory-mapping the file.
+    // Until then, replication can cause the process to use too much memory and get killed
+    // by the OS / cluster manager (not a Java OOM, since it's a memory-mapped file) even though
+    // we've read the data to disk.
     logDebug(s"Getting remote block $blockId")
     require(blockId != null, "BlockId is null")
     var runningFailureCount = 0
@@ -689,7 +696,7 @@ private[spark] class BlockManager(
       logDebug(s"Getting remote block $blockId from $loc")
       val data = try {
         blockTransferService.fetchBlockSync(
-          loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer()
+          loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager)
       } catch {
         case NonFatal(e) =>
           runningFailureCount += 1
@@ -723,7 +730,7 @@ private[spark] class BlockManager(
       }
 
       if (data != null) {
-        return Some(new ChunkedByteBuffer(data))
+        return Some(ChunkedByteBuffer.fromManagedBuffer(data, chunkSize))
       }
       logDebug(s"The value of block $blockId is null")
     }
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 700ce56466c35..efed90cb7678e 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -17,17 +17,21 @@
 
 package org.apache.spark.util.io
 
-import java.io.InputStream
+import java.io.{File, FileInputStream, InputStream}
 import java.nio.ByteBuffer
-import java.nio.channels.WritableByteChannel
+import java.nio.channels.{FileChannel, WritableByteChannel}
+import java.nio.file.StandardOpenOption
+
+import scala.collection.mutable.ListBuffer
 
 import com.google.common.primitives.UnsignedBytes
-import io.netty.buffer.{ByteBuf, Unpooled}
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.config
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.util.ByteArrayWritableChannel
 import org.apache.spark.storage.StorageUtils
+import org.apache.spark.util.Utils
 
 /**
  * Read-only byte buffer which is physically stored as multiple chunks rather than a single
@@ -81,10 +85,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
   }
 
   /**
-   * Wrap this buffer to view it as a Netty ByteBuf.
+   * Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB.
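+   * (A netty ByteBuf is limited to 2 GB in a single message, but a FileRegion can be streamed
+   * in smaller writes, even though the data here is not actually backed by a file.)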
    */
-  def toNetty: ByteBuf = {
-    Unpooled.wrappedBuffer(chunks.length, getChunks(): _*)
+  def toNetty: ChunkedByteBufferFileRegion = {
+    new ChunkedByteBufferFileRegion(this, bufferWriteChunkSize)
   }
 
   /**
@@ -166,6 +170,34 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
 
 }
 
+object ChunkedByteBuffer {
+  // TODO eliminate this method if we switch BlockManager to getting InputStreams
+  def fromManagedBuffer(data: ManagedBuffer, maxChunkSize: Int): ChunkedByteBuffer = {
+    data match {
+      case f: FileSegmentManagedBuffer =>
+        map(f.getFile, maxChunkSize, f.getOffset, f.getLength)
+      case other =>
+        new ChunkedByteBuffer(other.nioByteBuffer())
+    }
+  }
+
+  def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = {
+    Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel =>
+      var remaining = length
+      var pos = offset
+      val chunks = new ListBuffer[ByteBuffer]()
+      while (remaining > 0) {
+        val chunkSize = math.min(remaining, maxChunkSize)
+        val chunk = channel.map(FileChannel.MapMode.READ_ONLY, pos, chunkSize)
+        pos += chunkSize
+        remaining -= chunkSize
+        chunks += chunk
+      }
+      new ChunkedByteBuffer(chunks.toArray)
+    }
+  }
+}
+
 /**
  * Reads data from a ChunkedByteBuffer.
 *
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala
new file mode 100644
index 0000000000000..9622d0ac05368
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util.io
+
+import java.nio.channels.WritableByteChannel
+
+import io.netty.channel.FileRegion
+import io.netty.util.AbstractReferenceCounted
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.AbstractFileRegion
+
+
+/**
+ * This exposes a ChunkedByteBuffer as a netty FileRegion, just to allow sending > 2 GB in one
+ * netty message.  This is because netty cannot send a ByteBuf > 2 GB, but it can send a large
+ * FileRegion, even though the data is not backed by a file.
+ */
+private[io] class ChunkedByteBufferFileRegion(
+    private val chunkedByteBuffer: ChunkedByteBuffer,
+    private val ioChunkSize: Int) extends AbstractFileRegion {
+
+  private var _transferred: Long = 0
+  // this duplicates the original chunks, so we're free to modify the position, limit, etc.
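+  // (duplicating shares the underlying bytes, so no data is copied; only the position, limit
+  // and mark of each chunk are independent)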
+  private val chunks = chunkedByteBuffer.getChunks()
+  private val size = chunks.foldLeft(0L) { _ + _.remaining() }
+
+  protected def deallocate: Unit = {}
+
+  override def count(): Long = size
+
+  // this is the "start position" of the overall data in the backing file, not our current position
+  override def position(): Long = 0
+
+  override def transferred(): Long = _transferred
+
+  private var currentChunkIdx = 0
+
+  def transferTo(target: WritableByteChannel, position: Long): Long = {
+    assert(position == _transferred)
+    if (position == size) return 0L
+    var keepGoing = true
+    var written = 0L
+    var currentChunk = chunks(currentChunkIdx)
+    while (keepGoing) {
+      while (currentChunk.hasRemaining && keepGoing) {
+        val ioSize = Math.min(currentChunk.remaining(), ioChunkSize)
+        val originalLimit = currentChunk.limit()
+        currentChunk.limit(currentChunk.position() + ioSize)
+        val thisWriteSize = target.write(currentChunk)
+        currentChunk.limit(originalLimit)
+        written += thisWriteSize
+        if (thisWriteSize < ioSize) {
+          // the channel did not accept our entire write.  We do *not* keep trying -- netty wants
+          // us to just stop, and report how much we've written.
+          keepGoing = false
+        }
+      }
+      if (keepGoing) {
+        // advance to the next chunk (if there are any more)
+        currentChunkIdx += 1
+        if (currentChunkIdx == chunks.size) {
+          keepGoing = false
+        } else {
+          currentChunk = chunks(currentChunkIdx)
+        }
+      }
+    }
+    _transferred += written
+    written
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala
new file mode 100644
index 0000000000000..a6b0654204f34
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.spark.io + +import java.nio.ByteBuffer +import java.nio.channels.WritableByteChannel + +import scala.util.Random + +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.util.io.ChunkedByteBuffer + +class ChunkedByteBufferFileRegionSuite extends SparkFunSuite with MockitoSugar + with BeforeAndAfterEach { + + override protected def beforeEach(): Unit = { + super.beforeEach() + val conf = new SparkConf() + val env = mock[SparkEnv] + SparkEnv.set(env) + when(env.conf).thenReturn(conf) + } + + override protected def afterEach(): Unit = { + SparkEnv.set(null) + } + + private def generateChunkedByteBuffer(nChunks: Int, perChunk: Int): ChunkedByteBuffer = { + val bytes = (0 until nChunks).map { chunkIdx => + val bb = ByteBuffer.allocate(perChunk) + (0 until perChunk).foreach { idx => + bb.put((chunkIdx * perChunk + idx).toByte) + } + bb.position(0) + bb + }.toArray + new ChunkedByteBuffer(bytes) + } + + test("transferTo can stop and resume correctly") { + SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 9L) + val cbb = generateChunkedByteBuffer(4, 10) + val fileRegion = cbb.toNetty + + val targetChannel = new LimitedWritableByteChannel(40) + + var pos = 0L + // write the fileregion to the channel, but with the transfer limited at various spots along + // the way. + + // limit to within the first chunk + targetChannel.acceptNBytes = 5 + pos = fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 5) + + // a little bit further within the first chunk + targetChannel.acceptNBytes = 2 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 7) + + // past the first chunk, into the 2nd + targetChannel.acceptNBytes = 6 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 13) + + // right to the end of the 2nd chunk + targetChannel.acceptNBytes = 7 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 20) + + // rest of 2nd chunk, all of 3rd, some of 4th + targetChannel.acceptNBytes = 15 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 35) + + // now till the end + targetChannel.acceptNBytes = 5 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 40) + + // calling again at the end should be OK + targetChannel.acceptNBytes = 20 + fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 40) + } + + test(s"transfer to with random limits") { + val rng = new Random() + val seed = System.currentTimeMillis() + logInfo(s"seed = $seed") + rng.setSeed(seed) + val chunkSize = 1e4.toInt + SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, rng.nextInt(chunkSize).toLong) + + val cbb = generateChunkedByteBuffer(50, chunkSize) + val fileRegion = cbb.toNetty + val transferLimit = 1e5.toInt + val targetChannel = new LimitedWritableByteChannel(transferLimit) + while (targetChannel.pos < cbb.size) { + val nextTransferSize = rng.nextInt(transferLimit) + targetChannel.acceptNBytes = nextTransferSize + fileRegion.transferTo(targetChannel, targetChannel.pos) + } + assert(0 === fileRegion.transferTo(targetChannel, targetChannel.pos)) + } + + /** + * This mocks a channel which only accepts a limited number of bytes at a time. It also verifies + * the written data matches our expectations as the data is received. 
+ */ + private class LimitedWritableByteChannel(maxWriteSize: Int) extends WritableByteChannel { + val bytes = new Array[Byte](maxWriteSize) + var acceptNBytes = 0 + var pos = 0 + + override def write(src: ByteBuffer): Int = { + val length = math.min(acceptNBytes, src.remaining()) + src.get(bytes, 0, length) + acceptNBytes -= length + // verify we got the right data + (0 until length).foreach { idx => + assert(bytes(idx) === (pos + idx).toByte, s"; wrong data at ${pos + idx}") + } + pos += length + length + } + + override def isOpen: Boolean = true + + override def close(): Unit = {} + } + +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 2107559572d78..ff117b1c21cb1 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -34,7 +34,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { assert(emptyChunkedByteBuffer.getChunks().isEmpty) assert(emptyChunkedByteBuffer.toArray === Array.empty) assert(emptyChunkedByteBuffer.toByteBuffer.capacity() === 0) - assert(emptyChunkedByteBuffer.toNetty.capacity() === 0) + assert(emptyChunkedByteBuffer.toNetty.count() === 0) emptyChunkedByteBuffer.toInputStream(dispose = false).close() emptyChunkedByteBuffer.toInputStream(dispose = true).close() }