Changes from 27 commits

Commits (31)
5690fed
replace byte array with ChunkedByteBuffer in single task result colle…
liuzqt Oct 3, 2022
91695ef
Merge branch 'master' into SPARK-40622
liuzqt Oct 10, 2022
f9151a4
Update core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer…
liuzqt Oct 10, 2022
f486dd8
minor refine
liuzqt Oct 10, 2022
ec9c738
extract zero-copy writing in ChunkedByteBuffer as common util function
liuzqt Oct 10, 2022
fbbe788
use ByteBuffer for serializedResult
liuzqt Oct 10, 2022
babc68b
refine
liuzqt Oct 12, 2022
e557605
remove writeToStream
liuzqt Oct 12, 2022
e37f684
use getChunks
liuzqt Oct 12, 2022
985ccf8
refine
liuzqt Oct 13, 2022
5ddd527
always deserialize chunks into on-heap buffer
liuzqt Oct 13, 2022
364eae0
try to estimate chunk size in serialization
liuzqt Oct 14, 2022
5c64664
refine
liuzqt Oct 20, 2022
71dd6aa
update estimateBufferChunkSize
liuzqt Oct 25, 2022
b82a6ba
nit
liuzqt Oct 25, 2022
8153b77
try to estimate a reasonable upper bound of DirectTaskResult serializ…
liuzqt Oct 25, 2022
890ca25
fix test
liuzqt Oct 26, 2022
40ffef1
refactor: use Utils.writeByteBuffer
liuzqt Oct 26, 2022
f34755f
set initial value to ThreadLocal
liuzqt Oct 31, 2022
a22cae8
Merge branch 'master' into SPARK-40622
liuzqt Nov 1, 2022
2e68228
Merge remote-tracking branch 'upstream/master' into SPARK-40622
liuzqt Nov 4, 2022
cff1231
fix chinese double quote
liuzqt Nov 7, 2022
f2f47f8
try to remove scaladoc brackets introduced in this PR
liuzqt Nov 7, 2022
b2d4337
add org/apache/spark/util/io to Unidoc.ignoreUndocumentedPackages
liuzqt Nov 7, 2022
30f1805
fix KryoSerializerResizableOutputSuite
liuzqt Nov 8, 2022
5323c94
change LargeResult size to ~2.1GB
liuzqt Nov 8, 2022
b2b933b
remove large result test
liuzqt Nov 10, 2022
0853b95
Revert "remove large result test"
liuzqt Nov 14, 2022
20e2bcc
increase jvm mem to 6G in SparkBuild
liuzqt Nov 14, 2022
ae9c12b
set jvm mem to 5GB
liuzqt Nov 15, 2022
17d7ac7
skip DatasetLargeResultCollectingSuite in Github Action
liuzqt Nov 15, 2022
19 changes: 11 additions & 8 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -48,10 +48,10 @@ import org.apache.spark.metrics.source.JVMCPUSource
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler._
import org.apache.spark.serializer.SerializerHelper
import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer

/**
* Spark executor, backed by a threadpool to run tasks.
@@ -172,7 +172,7 @@ private[spark] class Executor(
env.serializerManager.setDefaultClassLoader(replClassLoader)

// Max size of direct result. If task result is bigger than this, we use the block manager
// to send the result back.
// to send the result back. This is guaranteed to be smaller than the array bytes limit (2GB).
private val maxDirectResultSize = Math.min(
conf.get(TASK_MAX_DIRECT_RESULT_SIZE),
RpcUtils.maxMessageSizeBytes(conf))
@@ -596,7 +596,7 @@ private[spark] class Executor(

val resultSer = env.serializer.newInstance()
val beforeSerializationNs = System.nanoTime()
val valueBytes = resultSer.serialize(value)
val valueByteBuffer = SerializerHelper.serializeToChunkedBuffer(resultSer, value)
val afterSerializationNs = System.nanoTime()

// Deserialization happens in two parts: first, we deserialize a Task object, which
@@ -659,9 +659,11 @@ private[spark] class Executor(
val accumUpdates = task.collectAccumulatorUpdates()
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId)
// TODO: do not serialize value twice
val directResult = new DirectTaskResult(valueBytes, accumUpdates, metricPeaks)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit()
val directResult = new DirectTaskResult(valueByteBuffer, accumUpdates, metricPeaks)
// try to estimate a reasonable upper bound of the DirectTaskResult's serialized size
val serializedDirectResult = SerializerHelper.serializeToChunkedBuffer(ser, directResult,
valueByteBuffer.size + accumUpdates.size * 32 + metricPeaks.length * 8)
val resultSize = serializedDirectResult.size

// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
@@ -674,13 +676,14 @@
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId,
new ChunkedByteBuffer(serializedDirectResult.duplicate()),
serializedDirectResult,
StorageLevel.MEMORY_AND_DISK_SER)
logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName. $resultSize bytes result sent to driver")
serializedDirectResult
// toByteBuffer is safe here, guarded by maxDirectResultSize
serializedDirectResult.toByteBuffer
}
}
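As a side note (not part of the diff), the size hint above can be read as a small worked example. The numbers below are invented purely for illustration; the hint only guides ChunkedByteBuffer's chunk sizing, so an inaccurate estimate affects allocation granularity rather than correctness.

// Invented numbers, for illustration only.
val valueSize = 64L * 1024 * 1024  // serialized task value: 64 MiB, the dominant term
val numAccums = 16                 // accumulator updates, budgeted at ~32 bytes each
val numPeaks  = 20                 // executor metric peaks, one Long (8 bytes) each
val sizeHint  = valueSize + numAccums * 32L + numPeaks * 8L  // ~64 MiB + 672 bytes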

core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -802,6 +802,8 @@ package object config {
ConfigBuilder("spark.task.maxDirectResultSize")
.version("2.0.0")
.bytesConf(ByteUnit.BYTE)
.checkValue(_ < ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong,
"The max direct result size is 2GB")
.createWithDefault(1L << 20)

private[spark] val TASK_MAX_FAILURES =
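For illustration only (not from the diff), a sketch of how the new bound surfaces to users, assuming an otherwise default SparkConf:

import org.apache.spark.SparkConf

// Accepted: comfortably below the ~2GB JVM array limit.
val conf = new SparkConf().set("spark.task.maxDirectResultSize", "512m")
// A value such as "3g" would instead be rejected by the checkValue above when the
// entry is read (e.g. via conf.get(TASK_MAX_DIRECT_RESULT_SIZE) in the Executor).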
27 changes: 16 additions & 11 deletions core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -24,44 +24,49 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.metrics.ExecutorMetricType
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
import org.apache.spark.storage.BlockId
import org.apache.spark.util.{AccumulatorV2, Utils}
import org.apache.spark.util.io.ChunkedByteBuffer

// Task result. Also contains updates to accumulator variables and executor metric peaks.
private[spark] sealed trait TaskResult[T]

/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Long)
extends TaskResult[T] with Serializable

/** A TaskResult that contains the task's return value, accumulator updates and metric peaks. */
private[spark] class DirectTaskResult[T](
var valueBytes: ByteBuffer,
var valueByteBuffer: ChunkedByteBuffer,
var accumUpdates: Seq[AccumulatorV2[_, _]],
var metricPeaks: Array[Long])
extends TaskResult[T] with Externalizable {

private var valueObjectDeserialized = false
private var valueObject: T = _

def this() = this(null.asInstanceOf[ByteBuffer], null,
def this(
valueByteBuffer: ByteBuffer,
accumUpdates: Seq[AccumulatorV2[_, _]],
metricPeaks: Array[Long]) = {
this(new ChunkedByteBuffer(Array(valueByteBuffer)), accumUpdates, metricPeaks)
}

def this() = this(null.asInstanceOf[ChunkedByteBuffer], Seq(),
new Array[Long](ExecutorMetricType.numMetrics))

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(valueBytes.remaining)
Utils.writeByteBuffer(valueBytes, out)
valueByteBuffer.writeExternal(out)
out.writeInt(accumUpdates.size)
accumUpdates.foreach(out.writeObject)
out.writeInt(metricPeaks.length)
metricPeaks.foreach(out.writeLong)
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val blen = in.readInt()
val byteVal = new Array[Byte](blen)
in.readFully(byteVal)
valueBytes = ByteBuffer.wrap(byteVal)
valueByteBuffer = new ChunkedByteBuffer()
valueByteBuffer.readExternal(in)

val numUpdates = in.readInt
if (numUpdates == 0) {
@@ -100,7 +105,7 @@ private[spark] class DirectTaskResult[T](
// This should not run when holding a lock because it may cost dozens of seconds for a large
// value
val ser = if (resultSer == null) SparkEnv.get.serializer.newInstance() else resultSer
valueObject = ser.deserialize(valueBytes)
valueObject = SerializerHelper.deserializeFromChunkedBuffer(ser, valueByteBuffer)
valueObjectDeserialized = true
valueObject
}
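A minimal construction sketch (not in the diff, and assuming code living inside the org.apache.spark package, since these classes are private[spark]): the auxiliary constructor keeps old ByteBuffer-based call sites working by wrapping the buffer in a single-chunk ChunkedByteBuffer.

import java.nio.ByteBuffer

import org.apache.spark.metrics.ExecutorMetricType
import org.apache.spark.scheduler.DirectTaskResult

val payload = ByteBuffer.wrap("hypothetical task value".getBytes("UTF-8"))
val result = new DirectTaskResult[String](
  payload,                                         // wrapped as new ChunkedByteBuffer(Array(payload))
  Seq.empty,                                       // no accumulator updates in this sketch
  new Array[Long](ExecutorMetricType.numMetrics))  // zeroed metric peaks
// writeExternal/readExternal now delegate the value bytes to ChunkedByteBuffer's
// writeExternal/readExternal instead of copying them through one large byte array.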
core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}

/**
@@ -63,7 +63,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
case directResult: DirectTaskResult[_] =>
if (!taskSetManager.canFetchMoreResults(directResult.valueBytes.limit())) {
if (!taskSetManager.canFetchMoreResults(directResult.valueByteBuffer.size)) {
// kill the task so that it will not become zombie task
scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
"Tasks result size has exceeded maxResultSize"))
@@ -73,7 +73,7 @@
// We should call it here, so that when it's called again in
// "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
directResult.value(taskResultSerializer.get())
(directResult, serializedData.limit())
(directResult, serializedData.limit().toLong)
case IndirectTaskResult(blockId, size) =>
if (!taskSetManager.canFetchMoreResults(size)) {
// dropped by executor if size is larger than maxResultSize
@@ -94,8 +94,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
return
}
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
serializedTaskResult.get.toByteBuffer)
val deserializedResult = SerializerHelper
.deserializeFromChunkedBuffer[DirectTaskResult[_]](
serializer.get(),
serializedTaskResult.get)
// force deserialization of referenced value
deserializedResult.value(taskResultSerializer.get())
sparkEnv.blockManager.master.removeBlock(blockId)
@@ -109,7 +111,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
val acc = a.asInstanceOf[LongAccumulator]
assert(acc.sum == 0L, "task result size should not have been set on the executors")
acc.setValue(size.toLong)
acc.setValue(size)
acc
} else {
a
core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala (new file)
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer

import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

private[spark] object SerializerHelper extends Logging {
Review comment (Contributor):
I was thinking, if this is specific to the chunked byte buffer, would it be beneficial to add this to ChunkedByteBuffer as a static method, e.g. ChunkedByteBuffer.serialize(...)? Just a question to see if there is a utility in the separate class.

Reply (Contributor, Author):
Ideally, SerializerInstance should support ByteBuffer/Stream/ChunkedByteBuffer, but that would add a new method to the SerializerInstance base class. To avoid that, I put this in a helper class (another option is to put it in SerializerInstance's companion object). I do think this method belongs with the serializer-related classes, though; we can leave it here or move it to SerializerInstance's companion object.


/**
* Serialize an object into a ChunkedByteBuffer.
*
* @param serializerInstance instance of SerializerInstance
* @param objectToSerialize the object to serialize, of type `T`
* @param estimatedSize estimated size of `objectToSerialize`, used as a hint to choose a proper chunk size
*/
def serializeToChunkedBuffer[T: ClassTag](
serializerInstance: SerializerInstance,
objectToSerialize: T,
estimatedSize: Long = -1): ChunkedByteBuffer = {
val chunkSize = ChunkedByteBuffer.estimateBufferChunkSize(estimatedSize)
val cbbos = new ChunkedByteBufferOutputStream(chunkSize, ByteBuffer.allocate)
val out = serializerInstance.serializeStream(cbbos)
out.writeObject(objectToSerialize)
out.close()
cbbos.close()
cbbos.toChunkedByteBuffer
}

def deserializeFromChunkedBuffer[T: ClassTag](
serializerInstance: SerializerInstance,
bytes: ChunkedByteBuffer): T = {
val in = serializerInstance.deserializeStream(bytes.toInputStream())
in.readObject()
}
}
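For illustration only (not part of the diff), a usage sketch of the two helpers, assuming code inside the org.apache.spark package (the object is private[spark]) and a running SparkEnv; the payload and size hint are made up.

import org.apache.spark.SparkEnv
import org.apache.spark.serializer.SerializerHelper
import org.apache.spark.util.io.ChunkedByteBuffer

val ser = SparkEnv.get.serializer.newInstance()
val payload: Seq[Int] = 1 to 1000                      // hypothetical task result value
// Serialize straight into chunked buffers; the estimate only influences chunk sizing.
val chunked: ChunkedByteBuffer =
  SerializerHelper.serializeToChunkedBuffer(ser, payload, estimatedSize = 8 * 1024)
// Deserialize by streaming over the chunks; no single contiguous array is required.
val restored = SerializerHelper.deserializeFromChunkedBuffer[Seq[Int]](ser, chunked)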
45 changes: 28 additions & 17 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -111,6 +111,12 @@ private[spark] object Utils extends Logging {

private val PATTERN_FOR_COMMAND_LINE_ARG = "-D(.+?)=(.+)".r

private val COPY_BUFFER_LEN = 1024

private val copyBuffer = ThreadLocal.withInitial[Array[Byte]](() => {
new Array[Byte](COPY_BUFFER_LEN)
})

/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -237,34 +243,39 @@
}
}

/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
*/
def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
private def writeByteBufferImpl(bb: ByteBuffer, writer: (Array[Byte], Int, Int) => Unit): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
// Avoid an extra copy if the ByteBuffer is backed by a byte array
writer(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
// Fallback to copy approach
val buffer = {
// reuse the copy buffer from the thread local
copyBuffer.get()
}
val originalPosition = bb.position()
val bbval = new Array[Byte](bb.remaining())
bb.get(bbval)
out.write(bbval)
var bytesToCopy = Math.min(bb.remaining(), COPY_BUFFER_LEN)
while (bytesToCopy > 0) {
bb.get(buffer, 0, bytesToCopy)
writer(buffer, 0, bytesToCopy)
bytesToCopy = Math.min(bb.remaining(), COPY_BUFFER_LEN)
}
bb.position(originalPosition)
}
}

/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
*/
def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
writeByteBufferImpl(bb, out.write)
}

/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
*/
def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
val originalPosition = bb.position()
val bbval = new Array[Byte](bb.remaining())
bb.get(bbval)
out.write(bbval)
bb.position(originalPosition)
}
writeByteBufferImpl(bb, out.write)
}

/**
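To close, a standalone sketch (plain Scala, no Spark dependencies; buffer sizes are arbitrary) of the chunked-copy approach writeByteBufferImpl uses when a buffer is not array-backed: stream through a small reusable array instead of allocating one array for the entire payload.

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

val direct = ByteBuffer.allocateDirect(10 * 1024)   // direct buffer, so hasArray == false
(0 until direct.capacity()).foreach(i => direct.put((i % 128).toByte))
direct.flip()

val out = new ByteArrayOutputStream()
val copyBuf = new Array[Byte](1024)                 // stands in for the thread-local copy buffer
val originalPosition = direct.position()
while (direct.hasRemaining) {
  val n = math.min(direct.remaining(), copyBuf.length)
  direct.get(copyBuf, 0, n)
  out.write(copyBuf, 0, n)
}
direct.position(originalPosition)                   // restore position, as writeByteBufferImpl does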