Changes from 9 commits (31 commits total)
5690fed  replace byte array with ChunkedByteBuffer in single task result colle… (liuzqt, Oct 3, 2022)
91695ef  Merge branch 'master' into SPARK-40622 (liuzqt, Oct 10, 2022)
f9151a4  Update core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer… (liuzqt, Oct 10, 2022)
f486dd8  minor refine (liuzqt, Oct 10, 2022)
ec9c738  extract zero-copy writing in ChunkedByteBuffer as common util function (liuzqt, Oct 10, 2022)
fbbe788  use ByteBuffer for serializedResult (liuzqt, Oct 10, 2022)
babc68b  refine (liuzqt, Oct 12, 2022)
e557605  remove writeToStream (liuzqt, Oct 12, 2022)
e37f684  use getChunks (liuzqt, Oct 12, 2022)
985ccf8  refine (liuzqt, Oct 13, 2022)
5ddd527  always deserialize chunks into on-heap buffer (liuzqt, Oct 13, 2022)
364eae0  try to estimate chunk size in serialization (liuzqt, Oct 14, 2022)
5c64664  refine (liuzqt, Oct 20, 2022)
71dd6aa  update estimateBufferChunkSize (liuzqt, Oct 25, 2022)
b82a6ba  nit (liuzqt, Oct 25, 2022)
8153b77  try to estimate a reasonable upper bound of DirectTaskResult serializ… (liuzqt, Oct 25, 2022)
890ca25  fix test (liuzqt, Oct 26, 2022)
40ffef1  refactor: use Utils.writeByteBuffer (liuzqt, Oct 26, 2022)
f34755f  set initial value to ThreadLocal (liuzqt, Oct 31, 2022)
a22cae8  Merge branch 'master' into SPARK-40622 (liuzqt, Nov 1, 2022)
2e68228  Merge remote-tracking branch 'upstream/master' into SPARK-40622 (liuzqt, Nov 4, 2022)
cff1231  fix chinese double quote (liuzqt, Nov 7, 2022)
f2f47f8  try to remove scaladoc brackets introduced in this PR (liuzqt, Nov 7, 2022)
b2d4337  add org/apache/spark/util/io to Unidoc.ignoreUndocumentedPackages (liuzqt, Nov 7, 2022)
30f1805  fix KryoSerializerResizableOutputSuite (liuzqt, Nov 8, 2022)
5323c94  change LargeResult size to ~2.1GB (liuzqt, Nov 8, 2022)
b2b933b  remove large result test (liuzqt, Nov 10, 2022)
0853b95  Revert "remove large result test" (liuzqt, Nov 14, 2022)
20e2bcc  increase jvm mem to 6G in SparkBuild (liuzqt, Nov 14, 2022)
ae9c12b  set jvm mem to 5GB (liuzqt, Nov 15, 2022)
17d7ac7  skip DatasetLargeResultCollectingSuite in Github Action (liuzqt, Nov 15, 2022)
17 changes: 9 additions & 8 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -48,10 +48,10 @@ import org.apache.spark.metrics.source.JVMCPUSource
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler._
import org.apache.spark.serializer.SerializerHelper
import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer

/**
* Spark executor, backed by a threadpool to run tasks.
@@ -172,7 +172,7 @@ private[spark] class Executor(
env.serializerManager.setDefaultClassLoader(replClassLoader)

// Max size of direct result. If task result is bigger than this, we use the block manager
// to send the result back.
// to send the result back. This is guaranteed to be smaller than array bytes limit (2GB)
private val maxDirectResultSize = Math.min(
conf.get(TASK_MAX_DIRECT_RESULT_SIZE),
RpcUtils.maxMessageSizeBytes(conf))
@@ -596,7 +596,7 @@ private[spark] class Executor(

val resultSer = env.serializer.newInstance()
val beforeSerializationNs = System.nanoTime()
val valueBytes = resultSer.serialize(value)
val valueByteBuffer = SerializerHelper.serializeToChunkedBuffer(resultSer, value)
val afterSerializationNs = System.nanoTime()

// Deserialization happens in two parts: first, we deserialize a Task object, which
@@ -659,9 +659,9 @@ private[spark] class Executor(
val accumUpdates = task.collectAccumulatorUpdates()
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId)
// TODO: do not serialize value twice
val directResult = new DirectTaskResult(valueBytes, accumUpdates, metricPeaks)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit()
val directResult = new DirectTaskResult(valueByteBuffer, accumUpdates, metricPeaks)
val serializedDirectResult = SerializerHelper.serializeToChunkedBuffer(ser, directResult)
Contributor:
Unlike the earlier invocation of serializeToChunkedBuffer (L#599), here we have a good estimate of the size - something to leverage to minimize the cost of serializeToChunkedBuffer.

Contributor Author:
Do you mean we can explicitly choose the chunk size here to avoid too-small/too-large chunks?

Contributor:
Yes, we do know valueByteBuffer.size - and we know what the serialization overhead of DirectTaskResult is. That will give us a much tighter bound on the size.
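A rough sketch of that suggestion (editor's illustration only: the extra estimatedSize parameter and the overhead allowance are assumptions, not code from the PR):

// Hypothetical overload; the estimatedSize parameter and overhead constants are assumed.
def serializeToChunkedBuffer[T: ClassTag](
    serializerInstance: SerializerInstance,
    t: T,
    estimatedSize: Long): ChunkedByteBuffer = {
  // Clamp the chunk size between 1 KB and 1 MB so tiny results do not over-allocate
  // and large results still spread across reasonably sized chunks.
  val chunkSize = math.min(math.max(estimatedSize, 1024L), 1024L * 1024L).toInt
  val cbbos = new ChunkedByteBufferOutputStream(chunkSize, ByteBuffer.allocate)
  val out = serializerInstance.serializeStream(cbbos)
  out.writeObject(t)
  out.close()
  cbbos.close()
  cbbos.toChunkedByteBuffer
}

// At this call site the payload size is already known, so something like
// (with an assumed per-accumulator allowance of a few hundred bytes):
//   SerializerHelper.serializeToChunkedBuffer(ser, directResult,
//     valueByteBuffer.size + accumUpdates.size * 256L + metricPeaks.length * 8L)
// gives a much tighter bound than the fixed 1 MB default chunk size.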
val resultSize = serializedDirectResult.size

// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
@@ -674,13 +674,14 @@
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId,
new ChunkedByteBuffer(serializedDirectResult.duplicate()),
serializedDirectResult,
StorageLevel.MEMORY_AND_DISK_SER)
logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName. $resultSize bytes result sent to driver")
serializedDirectResult
// toByteBuffer is safe here, guarded by maxDirectResultSize
serializedDirectResult.toByteBuffer
}
}

@@ -802,6 +802,8 @@ package object config {
ConfigBuilder("spark.task.maxDirectResultSize")
.version("2.0.0")
.bytesConf(ByteUnit.BYTE)
.checkValue(_ < ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong,
"The max direct result size is 2GB")
.createWithDefault(1L << 20)
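For illustration (assumed behavior, not text from the PR): with the new checkValue, a value at or above the 2GB array limit is rejected when the entry is read, instead of surfacing later as an oversized-array failure during result serialization.

// Hypothetical usage; these are private[spark] APIs, so this assumes code inside
// the org.apache.spark package, and the exact exception message may differ.
import org.apache.spark.SparkConf

val conf = new SparkConf().set("spark.task.maxDirectResultSize", "3g")
// Reading the entry applies the new checkValue and is expected to throw
// IllegalArgumentException mentioning "The max direct result size is 2GB".
conf.get(org.apache.spark.internal.config.TASK_MAX_DIRECT_RESULT_SIZE)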

private[spark] val TASK_MAX_FAILURES =
27 changes: 16 additions & 11 deletions core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -24,44 +24,49 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.metrics.ExecutorMetricType
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
import org.apache.spark.storage.BlockId
import org.apache.spark.util.{AccumulatorV2, Utils}
import org.apache.spark.util.io.ChunkedByteBuffer

// Task result. Also contains updates to accumulator variables and executor metric peaks.
private[spark] sealed trait TaskResult[T]

/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Long)
extends TaskResult[T] with Serializable

/** A TaskResult that contains the task's return value, accumulator updates and metric peaks. */
private[spark] class DirectTaskResult[T](
var valueBytes: ByteBuffer,
var valueByteBuffer: ChunkedByteBuffer,
var accumUpdates: Seq[AccumulatorV2[_, _]],
var metricPeaks: Array[Long])
extends TaskResult[T] with Externalizable {

private var valueObjectDeserialized = false
private var valueObject: T = _

def this() = this(null.asInstanceOf[ByteBuffer], null,
def this(
valueByteBuffer: ByteBuffer,
accumUpdates: Seq[AccumulatorV2[_, _]],
metricPeaks: Array[Long]) = {
this(new ChunkedByteBuffer(Array(valueByteBuffer)), accumUpdates, metricPeaks)
}

def this() = this(null.asInstanceOf[ChunkedByteBuffer], Seq(),
new Array[Long](ExecutorMetricType.numMetrics))

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(valueBytes.remaining)
Utils.writeByteBuffer(valueBytes, out)
valueByteBuffer.writeExternal(out)
out.writeInt(accumUpdates.size)
accumUpdates.foreach(out.writeObject)
out.writeInt(metricPeaks.length)
metricPeaks.foreach(out.writeLong)
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val blen = in.readInt()
val byteVal = new Array[Byte](blen)
in.readFully(byteVal)
valueBytes = ByteBuffer.wrap(byteVal)
valueByteBuffer = new ChunkedByteBuffer()
valueByteBuffer.readExternal(in)

val numUpdates = in.readInt
if (numUpdates == 0) {
@@ -100,7 +105,7 @@ private[spark] class DirectTaskResult[T](
// This should not run when holding a lock because it may cost dozens of seconds for a large
// value
val ser = if (resultSer == null) SparkEnv.get.serializer.newInstance() else resultSer
valueObject = ser.deserialize(valueBytes)
valueObject = SerializerHelper.deserializeFromChunkedBuffer(ser, valueByteBuffer)
valueObjectDeserialized = true
valueObject
}
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}

/**
@@ -63,7 +63,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
case directResult: DirectTaskResult[_] =>
if (!taskSetManager.canFetchMoreResults(directResult.valueBytes.limit())) {
if (!taskSetManager.canFetchMoreResults(directResult.valueByteBuffer.size)) {
// kill the task so that it will not become zombie task
scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
"Tasks result size has exceeded maxResultSize"))
@@ -73,7 +73,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
// We should call it here, so that when it's called again in
// "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
directResult.value(taskResultSerializer.get())
(directResult, serializedData.limit())
(directResult, serializedData.limit().toLong)
case IndirectTaskResult(blockId, size) =>
if (!taskSetManager.canFetchMoreResults(size)) {
// dropped by executor if size is larger than maxResultSize
@@ -94,8 +94,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
return
}
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
serializedTaskResult.get.toByteBuffer)
val deserializedResult = SerializerHelper
.deserializeFromChunkedBuffer[DirectTaskResult[_]](
serializer.get(),
serializedTaskResult.get)
// force deserialization of referenced value
deserializedResult.value(taskResultSerializer.get())
sparkEnv.blockManager.master.removeBlock(blockId)
@@ -109,7 +111,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
val acc = a.asInstanceOf[LongAccumulator]
assert(acc.sum == 0L, "task result size should not have been set on the executors")
acc.setValue(size.toLong)
acc.setValue(size)
acc
} else {
a
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer

import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

private[spark] object SerializerHelper extends Logging {
Contributor:
I was thinking: if this is specific to the chunked byte buffer, would it be beneficial to add this to ChunkedByteBuffer as a static method, e.g. ChunkedByteBuffer.serialize(...)? Just a question to see if there is utility in the separate class.

Contributor Author:
Ideally, the SerializerInstance should support ByteBuffer/Stream/ChunkedByteBuffer, but that would add new methods to the SerializerInstance base class. To avoid that, I put this in a helper class (another option is to put it in SerializerInstance's companion object). But I think this method should belong to a serializer-related class... we can leave it here or move it to SerializerInstance's companion object.

val CHUNK_BUFFER_SIZE: Int = 1024 * 1024

def serializeToChunkedBuffer[T: ClassTag](
serializerInstance: SerializerInstance,
t: T): ChunkedByteBuffer = {
val cbbos = new ChunkedByteBufferOutputStream(CHUNK_BUFFER_SIZE, ByteBuffer.allocate)
val out = serializerInstance.serializeStream(cbbos)
out.writeObject(t)
out.close()
cbbos.close()
cbbos.toChunkedByteBuffer
}

def deserializeFromChunkedBuffer[T: ClassTag](
serializerInstance: SerializerInstance,
bytes: ChunkedByteBuffer): T = {
val in = serializerInstance.deserializeStream(bytes.toInputStream())
in.readObject()
}
}
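A small usage sketch of the new helper (the JavaSerializer setup is illustrative; on the executor the serializer instances come from SparkEnv):

// Sketch: round-trip a value through the chunked path instead of a single ByteBuffer.
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, SerializerHelper, SerializerInstance}
import org.apache.spark.util.io.ChunkedByteBuffer

val ser: SerializerInstance = new JavaSerializer(new SparkConf()).newInstance()
val payload: Array[Byte] = Array.fill(16 * 1024 * 1024)(1.toByte)

// Serialized into ~1 MB chunks, so results larger than 2 GB no longer need one
// contiguous byte array.
val chunked: ChunkedByteBuffer = SerializerHelper.serializeToChunkedBuffer(ser, payload)

// Deserialization streams over the chunks via toInputStream().
val restored: Array[Byte] =
  SerializerHelper.deserializeFromChunkedBuffer[Array[Byte]](ser, chunked)
assert(restored.length == payload.length)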
@@ -17,7 +17,7 @@

package org.apache.spark.util.io

import java.io.{File, FileInputStream, InputStream}
import java.io.{Externalizable, File, FileInputStream, InputStream, ObjectInput, ObjectOutput, OutputStream}
import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel

@@ -42,8 +42,9 @@ import org.apache.spark.util.Utils
* buffers may also be used elsewhere then the caller is responsible for copying
* them as needed.
*/
private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) extends Externalizable {
Contributor:
Can you confirm that these are copies of the byte buffers, so there is no other reference that can be modified outside of the ChunkedByteBuffer based on your changes?

Contributor Author:
There is no physical byte-buffer copy happening here, so the chunks' underlying byte buffers are still modifiable from outside, and I think it's the caller's responsibility to manage the ownership. (But it is what it was, not introduced by this PR; I think the idea of ChunkedByteBuffer is more like a "view" of buffer chunks.)

But this PR can guarantee that the serialization won't modify any state of the ChunkedByteBuffer.

Contributor (@sadikovi, Oct 20, 2022):
I understand that, I was just asking if the references are maintained outside of this class; seems like it is not an issue.

Contributor Author:
Ah, got it. Yes, in the case of this PR this should not be an issue, since the chunk buffer is used within a pretty short life cycle during serialization/IO and reference ownership is clear. Thanks for the question.
require(chunks != null, "chunks must not be null")
require(!chunks.contains(null), "chunks must not contain null")
Contributor (@mridulm, Oct 13, 2022):
The next line should have implicitly checked for it - was this getting triggered?

Contributor Author:
The ChunkedByteBuffer has a constructor overload def this(byteBuffer: ByteBuffer), which is invoked with null in the DirectTaskResult empty constructor.

It was triggered in some temporary code snippet during my development, but it should not have any usage in the current code base (at least none can be found explicitly). That empty constructor might serve as a default constructor for some serialization code, I guess. After all, the code in ChunkedByteBuffer assumes chunks doesn't contain null, so I think there is no harm in doing this check at the beginning.

Contributor:
It is another iteration over chunks as part of construction. Not a big deal, but I was looking to minimize the impact of the change.
require(chunks.forall(_.position() == 0), "chunks' positions must be 0")

// Chunk size in bytes
@@ -56,7 +57,13 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
/**
* This size of this buffer, in bytes.
*/
val size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum
var _size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum

def size: Long = _size

def this() = {
this(Array.empty[ByteBuffer])
}

def this(byteBuffer: ByteBuffer) = {
this(Array(byteBuffer))
@@ -84,6 +91,74 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
}
}

/**
* write to ObjectOutput with zero copy if possible
Contributor:
nit: "Writes to the provided ObjectOutput with zero copy if possible."

Contributor:
Also, could you elaborate on zero copy? It seems a copy is required regardless.

Contributor Author:
Sorry for the confusion; "zero copy" here really means "avoid an extra copy".

  • When we have access to the underlying byte array, we only need one copy: underlying byte array --> ObjectOutput.
  • On the other hand, when we don't have access to the underlying byte array (when it's off-heap), we need an extra copy: byte buffer --> temp buffer --> ObjectOutput.
*/
override def writeExternal(out: ObjectOutput): Unit = {
// we want to keep the chunks layout
out.writeInt(chunks.length)
chunks.foreach(buffer => out.writeInt(buffer.limit()))
chunks.foreach(buffer => out.writeBoolean(buffer.isDirect))
var buffer: Array[Byte] = null
val bufferLen = ChunkedByteBuffer.COPY_BUFFER_LEN

getChunks().foreach { chunk => {
if (chunk.hasArray) {
// zero copy if the bytebuffer is backed by bytes array
out.write(chunk.array(), chunk.arrayOffset(), chunk.limit())
Contributor:
I believe this would still incur a copy, unless the implementation stores the reference and position.

Contributor Author:
I think array() returns a reference to the byte array that backs this buffer, which should not incur an extra copy.
} else {
// fallback to copy approach
if (buffer == null) {
buffer = new Array[Byte](bufferLen)
}
var bytesToRead = Math.min(chunk.remaining(), bufferLen)
while (bytesToRead > 0) {
chunk.get(buffer, 0, bytesToRead)
out.write(buffer, 0, bytesToRead)
bytesToRead = Math.min(chunk.remaining(), bufferLen)
}
Contributor:
nit: Pull this out into a separate method? I will need to search our codebase, but I would have expected this snippet to already be there somewhere.

Contributor Author:
There should be... but I was not able to locate it; I'm not very familiar with the code base. Feel free to let me know if you find it.

Contributor:
There is Utils.writeByteBuffer, which writes a buffer to an OutputStream. It does not do a very good job, since it allocates a buffer the same size as remaining - instead, we should enhance it to do what this method is doing.

Additionally, we can use a ThreadLocal[Array[Byte]] in Utils for use with this copy (the buffer here).

Contributor Author:
I tried to reuse Utils.writeByteBuffer, and I noticed that there are two versions of Utils.writeByteBuffer, for OutputStream and DataOutput respectively, and the code is identical, so I added a Utils.writeByteBufferImpl to extract the common logic and also added the ThreadLocal[Array[Byte]].

}
}}
}
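For context, the Utils.writeByteBuffer refactor discussed in the thread above would look roughly like this (a sketch under assumed names; the buffer length and the ThreadLocal field are illustrative, not the actual patch):

// Hypothetical shape of Utils.writeByteBufferImpl with a reusable thread-local copy buffer.
private val copyBuffer: ThreadLocal[Array[Byte]] =
  ThreadLocal.withInitial[Array[Byte]](() => new Array[Byte](1024 * 1024))

private def writeByteBufferImpl(bb: ByteBuffer, write: (Array[Byte], Int, Int) => Unit): Unit = {
  if (bb.hasArray) {
    // On-heap: hand the backing array straight to the sink, no intermediate copy.
    write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
  } else {
    // Off-heap: drain through a bounded, reused scratch buffer instead of
    // allocating an array as large as bb.remaining().
    val buffer = copyBuffer.get()
    val originalPosition = bb.position()
    while (bb.hasRemaining) {
      val len = math.min(bb.remaining(), buffer.length)
      bb.get(buffer, 0, len)
      write(buffer, 0, len)
    }
    bb.position(originalPosition)
  }
}

// The two public overloads then just adapt their sink to the shared implementation.
def writeByteBuffer(bb: ByteBuffer, out: java.io.DataOutput): Unit =
  writeByteBufferImpl(bb, (arr, off, len) => out.write(arr, off, len))

def writeByteBuffer(bb: ByteBuffer, out: java.io.OutputStream): Unit =
  writeByteBufferImpl(bb, (arr, off, len) => out.write(arr, off, len))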

override def readExternal(in: ObjectInput): Unit = {
val chunksNum = in.readInt()
val indices = 0 until chunksNum
val chunksSize = indices.map(_ => in.readInt())
val chunksDirect = indices.map(_ => in.readBoolean())
Contributor:
Do we have cases where we want to preserve whether the buffer was direct or not across VMs? The current use case does not require it? +CC @Ngone51

If not, drop this and simplify the implementation?

Member (replying to the question above):
I don't have such cases in mind.

Contributor Author:
Shall we default to an on-heap buffer in deserialization, regardless of what it was before serialization? @mridulm @Ngone51

Contributor:
Yes, let us do that - if some specific reason comes up in the future, we can add support as required. It will make the implementation much simpler.

Contributor Author:
Thanks for the suggestion, updated.

Contributor (@mridulm, Oct 26, 2022):
Looks like we are still preserving chunksDirect - remove it? And always deserialize to a heap buffer via ByteBuffer.wrap?

val chunks = new Array[ByteBuffer](chunksNum)

val copyBufferLen = ChunkedByteBuffer.COPY_BUFFER_LEN
val copyBuffer: Array[Byte] = if (chunksDirect.exists(identity)) {
new Array[Byte](copyBufferLen)
} else {
null
}

indices.foreach { i => {
val chunkSize = chunksSize(i)
chunks(i) = if (chunksDirect(i)) {
val buffer = ByteBuffer.allocateDirect(chunkSize)
var bytesRemaining = chunkSize
var bytesToRead = Math.min(copyBufferLen, bytesRemaining)
while (bytesRemaining > 0) {
bytesToRead = Math.min(copyBufferLen, bytesRemaining)
in.readFully(copyBuffer, 0, bytesToRead)
buffer.put(copyBuffer, 0, bytesToRead)
bytesRemaining -= bytesToRead
}
buffer.rewind()
buffer
} else {
val arr = new Array[Byte](chunkSize)
in.readFully(arr, 0, chunkSize)
ByteBuffer.wrap(arr)
}
}}
this.chunks = chunks
this._size = chunks.map(_.limit().toLong).sum
}

/**
* Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB.
*/
@@ -172,6 +247,8 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {

private[spark] object ChunkedByteBuffer {

val COPY_BUFFER_LEN: Int = 1024 * 1024
Contributor (@mridulm, Oct 13, 2022):
It might be messy to add a dependency from here to SerializerHelper - add a comment in both places to indicate they should ideally be kept in sync? And make both of them private to the class?

Contributor Author:
I added a def estimateBufferChunkSize(estimatedSize: Long = -1) to be used for both. But I'm not sure if the heuristic is appropriate.

Or another option: we can use 1024 (1 KB) for all and keep it simple. I did some quick benchmarks; 1 KB isn't too bad compared to 1 MB even for a large result, and the overhead upper bound is reasonable even when the result is very tiny (actually, even a nearly empty result will still be serialized to a few hundred bytes because of other metrics and accumulators).

WDYT?

Contributor (@mridulm, Oct 19, 2022):
Add a comment which might help ...
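For concreteness, the shared heuristic described above might look like this (a sketch; the 1 KB floor and where the method ultimately lives are assumptions):

// Hypothetical: derive the chunk size from an optional size estimate, clamped
// between a small floor (cheap for near-empty results) and the 1 MB default.
def estimateBufferChunkSize(estimatedSize: Long = -1): Int = {
  val minChunkSize = 1024             // assumed floor
  if (estimatedSize < 0) {
    COPY_BUFFER_LEN                   // no estimate: fall back to the 1 MB default above
  } else {
    math.max(math.min(estimatedSize, COPY_BUFFER_LEN.toLong).toInt, minChunkSize)
  }
}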


def fromManagedBuffer(data: ManagedBuffer): ChunkedByteBuffer = {
data match {
case f: FileSegmentManagedBuffer =>