diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6f25d346e6e5..f8164c14f3f7 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -42,7 +42,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
 import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.broadcast.{Broadcast, BroadcastMode}
 import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
 import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat}
 import org.apache.spark.internal.Logging
@@ -1487,6 +1487,27 @@ class SparkContext(config: SparkConf) extends Logging {
     bc
   }
 
+  /**
+   * :: DeveloperApi ::
+   * Broadcast a read-only variable to the cluster, returning a
+   * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
+   * The variable will be sent to each executor only once.
+   *
+   * Note that the RDD to be broadcasted should be cached and materialized first, so that its
+   * data can be accessed on the executors.
+   */
+  @DeveloperApi
+  def broadcastRDDOnExecutor[T: ClassTag, U: ClassTag](
+      rdd: RDD[T], mode: BroadcastMode[T]): Broadcast[U] = {
+    assertNotStopped()
+    val bc = env.broadcastManager.newBroadcastOnExecutor[T, U](rdd, mode, isLocal)
+    rdd.broadcast(bc)
+    val callSite = getCallSite
+    logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
+    cleaner.foreach(_.registerBroadcastForCleanup(bc))
+    bc
+  }
+
   /**
    * Add a file to be downloaded with this Spark job on every node.
    *
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index ece4ae6ab031..b597decd3ae8 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -19,8 +19,8 @@ package org.apache.spark.broadcast
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.SecurityManager
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.rdd.RDD
 
 /**
  * An interface for all the broadcast implementations in Spark (to allow
@@ -40,6 +40,21 @@ private[spark] trait BroadcastFactory {
    */
   def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
 
+  /**
+   * Creates a new broadcast variable which is broadcasted among executors without first
+   * collecting the data to the driver.
+   *
+   * @param rdd the RDD to be broadcasted among executors
+   * @param mode the broadcast mode used to transform the RDD's results into the broadcasted object
+   * @param isLocal whether we are in local mode (single JVM process)
+   * @param id unique id representing this broadcast variable
+   */
+  def newBroadcastOnExecutor[T: ClassTag, U: ClassTag](
+      rdd: RDD[T],
+      mode: BroadcastMode[T],
+      isLocal: Boolean,
+      id: Long): Broadcast[U]
+
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
 
   def stop(): Unit
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index e88988fe03b2..aee16ac3f153 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
 
 private[spark] class BroadcastManager(
     val isDriver: Boolean,
@@ -56,6 +57,14 @@ private[spark] class BroadcastManager(
     broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
   }
 
+  def newBroadcastOnExecutor[T: ClassTag, U: ClassTag](
+      rdd_ : RDD[T],
+      mode: BroadcastMode[T],
+      isLocal: Boolean): Broadcast[U] = {
+    broadcastFactory.newBroadcastOnExecutor[T, U](rdd_, mode, isLocal,
+      nextBroadcastId.getAndIncrement())
+  }
+
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
     broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
   }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastMode.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastMode.scala
new file mode 100644
index 000000000000..ac42f03ecb22
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastMode.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+/**
+ * The trait used in executor-side broadcast. The implementation of `transform` identifies the
+ * shape in which the results of an RDD are broadcasted.
+ *
+ * @tparam T The type of RDD elements.
+ */ +trait BroadcastMode[T] extends Serializable { + def transform(rows: Array[T]): Any + def transform(rows: Iterator[T], sizeHint: Option[Long]): Any + def canonicalized: BroadcastMode[T] = this +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 67e993c7f02e..f67939f0089f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -209,6 +209,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) val blockManager = SparkEnv.get.blockManager blockManager.getLocalValues(broadcastId) match { case Some(blockResult) => + // Found broadcasted value in local [[BlockManager]]. Use it directly. if (blockResult.data.hasNext) { val x = blockResult.data.next().asInstanceOf[T] releaseLock(broadcastId) @@ -217,6 +218,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") } case None => + // Not found. Going to fetch the chunks of the broadcasted value from driver/executors. logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() val blocks = readBlocks() diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index b11f9ba171b8..35be3ce8a67e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -20,9 +20,10 @@ package org.apache.spark.broadcast import scala.reflect.ClassTag import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.rdd.RDD /** - * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like + * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a BitTorrent-like * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details. */ @@ -34,6 +35,13 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory { new TorrentBroadcast[T](value_, id) } + override def newBroadcastOnExecutor[T: ClassTag, U: ClassTag]( + rdd: RDD[T], + mode: BroadcastMode[T], + isLocal: Boolean, id: Long): Broadcast[U] = { + new TorrentExecutorBroadcast[T, U](rdd, mode, id) + } + override def stop() { } /** diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentExecutorBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentExecutorBroadcast.scala new file mode 100644 index 000000000000..447593ee5ed0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentExecutorBroadcast.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io.ObjectOutputStream
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockId, BlockResult, BroadcastBlockId, RDDBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
+ *
+ * Unlike [[TorrentBroadcast]], this implementation does not divide the object to broadcast.
+ * Instead, it performs the broadcast on the executor side for an RDD, so the results of the RDD
+ * do not need to be collected back to the driver before broadcasting.
+ *
+ * The mechanism is as follows:
+ *
+ * Each executor first attempts to fetch the object from its BlockManager. If it does not exist,
+ * it then uses remote fetches to get the blocks of the RDD from other executors, if available.
+ * Once it gets the blocks, it puts them in its own BlockManager, ready for other executors to
+ * fetch from.
+ *
+ * @tparam T The type of the elements of the RDD to be broadcasted.
+ * @tparam U The type of the object transformed from the collection of elements of the RDD.
+ *
+ * @param rdd The RDD to be broadcasted on executors.
+ * @param mode The [[org.apache.spark.broadcast.BroadcastMode]] object used to transform the
+ *             RDD's results into the object which will be stored in the block manager.
+ * @param id A unique identifier for the broadcast variable.
+ */
+private[spark] class TorrentExecutorBroadcast[T: ClassTag, U: ClassTag](
+    @transient private val rdd: RDD[T],
+    mode: BroadcastMode[T],
+    id: Long) extends Broadcast[U](id) with Logging with Serializable {
+
+  // Total number of blocks this broadcast variable contains.
+  private val numBlocks: Int = rdd.getNumPartitions
+  // The id of the RDD to be broadcasted on executors.
+  private val rddId: Int = rdd.id
+
+  /**
+   * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+   * which builds this value by reading blocks from other executors.
+   */
+  @transient private lazy val _value: U = readBroadcastBlock()
+
+  private val broadcastId = BroadcastBlockId(id)
+
+  override protected def getValue() = {
+    _value
+  }
+
+  /** Fetch torrent blocks from other executors. */
+  private def readBlocks(): Array[T] = {
+    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+    // to the driver, so other executors can pull these chunks from this executor as well.
+    val blocks = new Array[Array[T]](numBlocks)
+    val bm = SparkEnv.get.blockManager
+
+    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+      val pieceId = RDDBlockId(rddId, pid)
+      // First try getLocalValues because there is a chance that previous attempts to fetch the
+      // broadcast blocks have already fetched some of the blocks. In that case, some blocks
+      // would be available locally (on this executor).
+ bm.getLocalValues(pieceId) match { + case Some(block: BlockResult) => + blocks(pid) = block.data.asInstanceOf[Iterator[T]].toArray + case None => + bm.get[T](pieceId) match { + case Some(b) => + val data = b.data.asInstanceOf[Iterator[T]].toArray + // We found the block from remote executors' BlockManager, so put the block + // in this executor's BlockManager. + if (!bm.putIterator(pieceId, data.toIterator, + StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { + throw new SparkException( + s"Failed to store $pieceId of $broadcastId in local BlockManager") + } + blocks(pid) = data + case None => + throw new SparkException(s"Failed to get $pieceId of $broadcastId") + } + } + } + blocks.flatMap(x => x) + } + + /** + * Remove all persisted state associated with this Torrent broadcast on the executors. + */ + override protected def doUnpersist(blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking) + } + + /** + * Remove all persisted state associated with this Torrent broadcast on the executors + * and driver. + */ + override protected def doDestroy(blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) + } + + /** Used by the JVM when serializing this object. */ + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + assertValid() + out.defaultWriteObject() + } + + private def readBroadcastBlock(): U = Utils.tryOrIOException { + TorrentBroadcast.synchronized { + val blockManager = SparkEnv.get.blockManager + blockManager.getLocalValues(broadcastId).map(_.data.next()) match { + case Some(x) => + // Found broadcasted value in local [[BlockManager]]. Use it directly. + releaseLock(broadcastId) + x.asInstanceOf[U] + + case None => + // Not found. Going to fetch the chunks of the broadcasted value from executors. + logInfo("Started reading broadcast variable " + id) + val startTimeMs = System.currentTimeMillis() + val rawInput = readBlocks() + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) + + val obj = mode.transform(rawInput.toArray).asInstanceOf[U] + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + obj + } + } + } + + /** + * If running in a task, register the given block's locks for release upon task completion. + * Otherwise, if not running in a task then immediately release the lock. + */ + private def releaseLock(blockId: BlockId): Unit = { + val blockManager = SparkEnv.get.blockManager + Option(TaskContext.get()) match { + case Some(taskContext) => + taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId)) + case None => + // This should only happen on the driver, where broadcast variables may be accessed + // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow + // broadcast variables to be garbage collected we need to free the reference here + // which is slightly unsafe but is technically okay because broadcast variables aren't + // stored off-heap. 
+        blockManager.releaseLock(blockId)
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 8798dfc92536..8ad9becdc29a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,6 +35,7 @@ import org.apache.spark._
 import org.apache.spark.Partitioner._
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
@@ -206,6 +207,24 @@ abstract class RDD[T: ClassTag](
    */
  def cache(): this.type = persist()

+  /**
+   * Broadcast this RDD on executors. The executor-side broadcast variable is created by
+   * [[SparkContext]]. This RDD should be cached and materialized before calling this method.
+   */
+  private[spark] def broadcast[U: ClassTag](broadcasted: Broadcast[U]): Unit = {
+    // The RDD must be cached and materialized before it can be broadcasted on the executor
+    // side. We check that here.
+    if (storageLevel == StorageLevel.NONE) {
+      throw new SparkException("To broadcast this RDD on executors, it should be cached first.")
+    }
+    // Create the executor-side broadcast object on the executors.
+    mapPartitionsInternal { iter: Iterator[T] =>
+      broadcasted.value
+      Iterator.empty.asInstanceOf[Iterator[T]]
+    }.count
+  }
+
   /**
    * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
    *
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e994d724c462..3c0bf248b230 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -29,7 +29,9 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.broadcast.BroadcastMode
 import org.apache.spark.rdd.RDDSuiteUtils._
+import org.apache.spark.storage.BroadcastBlockId
 import org.apache.spark.util.{ThreadUtils, Utils}
 
 class RDDSuite extends SparkFunSuite with SharedSparkContext {
@@ -1113,6 +1115,58 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(totalPartitionCount == 10)
   }
 
+  test("executor side broadcast for RDD") {
+    // Materialize and cache the RDD to be broadcasted on executors.
+    val rdd = sc.parallelize(1 to 4, 2).cache()
+    rdd.count()
+    val mode = new BroadcastMode[Int] {
+      override def transform(rows: Array[Int]): Array[Int] = rows
+      override def transform(rows: Iterator[Int], sizeHint: Option[Long]): Array[Int] = rows.toArray
+    }
+    val broadcastedVal = sc.broadcastRDDOnExecutor[Int, Array[Int]](rdd, mode)
+    val collected = sc.parallelize(1 to 2, 2).map { _ =>
+      broadcastedVal.value.reduce(_ + _) // 1 + 2 + 3 + 4 = 10
+    }.collect()
+    assert(broadcastedVal.value.sum == 10)
+    assert(collected.sum == 20)
+  }
+
+  test("executor side broadcast for RDD: unbroadcast") {
+    // Materialize and cache the RDD to be broadcasted on executors.
+ val rdd = sc.parallelize(1 to 4, 2).cache() + rdd.count() + val mode = new BroadcastMode[Int] { + override def transform(rows: Array[Int]): Int = 1 + override def transform(rows: Iterator[Int], sizeHint: Option[Long]): Int = 1 + } + val broadcastedVal = sc.broadcastRDDOnExecutor[Int, Int](rdd, mode) + val collected = sc.parallelize(1 to 2, 2).map { _ => + broadcastedVal.value + }.collect() + val blockId = BroadcastBlockId(broadcastedVal.id) + assert(sc.env.blockManager.getSingle(blockId).isDefined) + sc.env.blockManager.releaseLock(blockId) + // Unbroadcast it. + sc.env.broadcastManager.unbroadcast(broadcastedVal.id, true, true) + assert(sc.env.blockManager.getSingle(blockId).isEmpty) + } + + test("executor side broadcast for RDD: unpersist RDD") { + // Materialize and cache the RDD to be broadcasted on executors. + val rdd = sc.parallelize(1 to 4, 2).cache() + rdd.count() + val mode = new BroadcastMode[Int] { + override def transform(rows: Array[Int]): Int = 1 + override def transform(rows: Iterator[Int], sizeHint: Option[Long]): Int = 1 + } + val broadcastedVal = sc.broadcastRDDOnExecutor[Int, Int](rdd, mode) + rdd.unpersist() + val collected = sc.parallelize(1 to 2, 2).map { _ => + broadcastedVal.value + }.collect() + assert(collected.sum == 2) + } + test("SPARK-18406: race between end-of-task and completion iterator read lock release") { val rdd = sc.parallelize(1 to 1000, 10) rdd.cache() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 9fac95aed8f1..16e380161624 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -17,30 +17,25 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.broadcast.BroadcastMode import org.apache.spark.sql.catalyst.InternalRow /** * Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are * identity (tuples remain unchanged) or hashed (tuples are converted into some hash index). */ -trait BroadcastMode { - def transform(rows: Array[InternalRow]): Any - - def transform(rows: Iterator[InternalRow], sizeHint: Option[Long]): Any - - def canonicalized: BroadcastMode +abstract class RowBroadcastMode extends BroadcastMode[InternalRow] { + override def canonicalized: RowBroadcastMode = this } /** * IdentityBroadcastMode requires that rows are broadcasted in their original form. */ -case object IdentityBroadcastMode extends BroadcastMode { +case object IdentityBroadcastMode extends RowBroadcastMode { // TODO: pack the UnsafeRows into single bytes array. 
override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows override def transform( rows: Iterator[InternalRow], sizeHint: Option[Long]): Array[InternalRow] = rows.toArray - - override def canonicalized: BroadcastMode = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e57c842ce2a3..3ee10355723a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -81,7 +82,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { * Represents data where tuples are broadcasted to every node. It is quite common that the * entire set of tuples is transformed into different data structure. */ -case class BroadcastDistribution(mode: BroadcastMode) extends Distribution +case class BroadcastDistribution(mode: RowBroadcastMode) extends Distribution /** * Describes how an operator's output is split across partitions. The `compatibleWith`, @@ -370,7 +371,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Represents a partitioning where rows are collected, transformed and broadcasted to each * node in the cluster. */ -case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { +case class BroadcastPartitioning(mode: RowBroadcastMode) extends Partitioning { override val numPartitions: Int = 1 override def satisfies(required: Distribution): Boolean = required match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9c7d47f99ee1..5807ef580498 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -27,6 +27,7 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.broadcast.BroadcastMode import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.ScalaReflection._ import org.apache.spark.sql.catalyst.TableIdentifier @@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -690,7 +691,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case catalog: CatalogTable => true case partition: Partitioning => true case resource: FunctionResource => true - case broadcast: BroadcastMode => true + case broadcast: BroadcastMode[_] => true case table: CatalogTableType => true case storage: CatalogStorageFormat => true case _ => false diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ede116e964a0..cf5bd5e746dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -409,6 +409,16 @@ object SQLConf {
     .timeConf(TimeUnit.SECONDS)
     .createWithDefault(5 * 60)
 
+  val EXECUTOR_SIDE_BROADCAST_ENABLED = buildConf("spark.sql.executorSideBroadcast.enabled")
+    .doc("When true, executor-side broadcast is used for broadcast-based joins in SQL. " +
+      "Note that with executor-side broadcast the broadcasted pieces of data are not " +
+      "persisted on the driver, but are fetched from RDD blocks persisted on other executors. " +
+      "If an executor is lost before its piece has been fetched by other executors, " +
+      "that piece cannot be recovered and the broadcast will fail. It is therefore not " +
+      "guaranteed to be completely safe when used with dynamic allocation.")
+    .booleanConf
+    .createWithDefault(true)
+
   // This is only used for the thriftserver
   val THRIFTSERVER_POOL = buildConf("spark.sql.thriftserver.scheduler.pool")
     .doc("Set a Fair Scheduler pool for a JDBC client session.")
@@ -1151,6 +1161,8 @@ class SQLConf extends Serializable with Logging {
 
   def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT)
 
+  def executorSideBroadcastEnabled: Boolean = getConf(EXECUTOR_SIDE_BROADCAST_ENABLED)
+
   def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME)
 
   def convertCTAS: Boolean = getConf(CONVERT_CTAS)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index daea6c39624d..c2b91a67b4e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -19,32 +19,42 @@ package org.apache.spark.sql.execution.exchange
 
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
+import scala.reflect.ClassTag
 
 import org.apache.spark.{broadcast, SparkException}
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastPartitioning, Partitioning, RowBroadcastMode}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.HashedRelation
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ThreadUtils
 
 /**
  * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
  * a transformed SparkPlan.
+ *
+ * @tparam T The type of the object transformed from the result of the RDD by [[BroadcastMode]].
 */
-case class BroadcastExchangeExec(
-    mode: BroadcastMode,
+case class BroadcastExchangeExec[T: ClassTag](
+    mode: RowBroadcastMode,
     child: SparkPlan) extends Exchange {
 
-  override lazy val metrics = Map(
-    "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
-    "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"),
-    "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
-    "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
+  override lazy val metrics = if (sqlContext.conf.executorSideBroadcastEnabled) {
+    Map(
+      "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
+      "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
+  } else {
+    Map(
+      "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+      "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"),
+      "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
+      "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
+  }
 
   override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
 
@@ -62,6 +72,74 @@ case class BroadcastExchangeExec(
     }
   }
 
+  // Private variable used to hold a reference to the RDD created during executor-side
+  // broadcasting. If we don't keep the reference, the RDD may be cleaned up.
+  private var childRDD: RDD[InternalRow] = null
+
+  private def executorSideBroadcast(): broadcast.Broadcast[Any] = {
+    val beforeBuild = System.nanoTime()
+    // Call persist on the RDD because we want to broadcast the RDD blocks on executors.
+    childRDD = child.execute().mapPartitionsInternal { rowIterator =>
+      rowIterator.map(_.copy())
+    }.persist(StorageLevel.MEMORY_AND_DISK)
+
+    val numOfRows = childRDD.count()
+    if (numOfRows >= 512000000) {
+      throw new SparkException(
+        s"Cannot broadcast the table with more than 512 millions rows: ${numOfRows} rows")
+    }
+
+    // Broadcast the relation on executors.
+    val beforeBroadcast = System.nanoTime()
+    longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
+
+    val broadcasted = sparkContext.broadcastRDDOnExecutor[InternalRow, T](childRDD, mode)
+      .asInstanceOf[broadcast.Broadcast[Any]]
+
+    longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
+    broadcasted
+  }
+
+  private def driverSideBroadcast(): broadcast.Broadcast[Any] = {
+    val beforeCollect = System.nanoTime()
+    // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
+    val (numRows, input) = child.executeCollectIterator()
+    if (numRows >= 512000000) {
+      throw new SparkException(
+        s"Cannot broadcast the table with more than 512 millions rows: $numRows rows")
+    }
+
+    val beforeBuild = System.nanoTime()
+    longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+
+    // Construct the relation.
+ val relation = mode.transform(input, Some(numRows)) + + val dataSize = relation match { + case map: HashedRelation => + map.estimatedSize + case arr: Array[InternalRow] => + arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + case _ => + throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " + + relation.getClass.getName) + } + + longMetric("dataSize") += dataSize + if (dataSize >= (8L << 30)) { + throw new SparkException( + s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") + } + + val beforeBroadcast = System.nanoTime() + longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 + + // Broadcast the relation + val broadcasted = sparkContext.broadcast(relation) + longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 + broadcasted + } + @transient private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. @@ -71,43 +149,11 @@ case class BroadcastExchangeExec( // with the correct execution. SQLExecution.withExecutionId(sparkContext, executionId) { try { - val beforeCollect = System.nanoTime() - // Use executeCollect/executeCollectIterator to avoid conversion to Scala types - val (numRows, input) = child.executeCollectIterator() - if (numRows >= 512000000) { - throw new SparkException( - s"Cannot broadcast the table with more than 512 millions rows: $numRows rows") + val broadcasted = if (sqlContext.conf.executorSideBroadcastEnabled) { + executorSideBroadcast() + } else { + driverSideBroadcast() } - - val beforeBuild = System.nanoTime() - longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 - - // Construct the relation. - val relation = mode.transform(input, Some(numRows)) - - val dataSize = relation match { - case map: HashedRelation => - map.estimatedSize - case arr: Array[InternalRow] => - arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum - case _ => - throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " + - relation.getClass.getName) - } - - longMetric("dataSize") += dataSize - if (dataSize >= (8L << 30)) { - throw new SparkException( - s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") - } - - val beforeBroadcast = System.nanoTime() - longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 - - // Broadcast the relation - val broadcasted = sparkContext.broadcast(relation) - longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { @@ -135,6 +181,8 @@ case class BroadcastExchangeExec( override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] } + + override protected def otherCopyArgs: Seq[AnyRef] = Seq(implicitly[ClassTag[T]]) } object BroadcastExchangeExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4e2ca37bc1a5..0001a70108b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.exchange +import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.internal.SQLConf /** @@ -163,7 +165,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => - BroadcastExchangeExec(mode, child) + mode match { + case IdentityBroadcastMode => BroadcastExchangeExec[Array[InternalRow]](mode, child) + case _: HashedRelationBroadcastMode => BroadcastExchangeExec[HashedRelation](mode, child) + } case (child, distribution) => ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b2dcbe5aa987..ffbb8c804c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.memory.{MemoryConsumer, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.catalyst.plans.physical.RowBroadcastMode import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap @@ -863,7 +863,7 @@ private[joins] object LongHashedRelation { /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. 
*/ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) - extends BroadcastMode { + extends RowBroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { transform(rows.iterator, Some(rows.length)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 17c88b069080..099e9734199b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1786,7 +1786,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { join2.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) assert( join2.queryExecution.executedPlan - .collect { case e: BroadcastExchangeExec => true }.size === 1) + .collect { case e: BroadcastExchangeExec[_] => true }.size === 1) assert( join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size == 4) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index aac8d56ba620..19528dafac4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -55,12 +56,12 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val output = plan.output assert(plan sameResult plan) - val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan) + val exchange1 = BroadcastExchangeExec[Array[InternalRow]](IdentityBroadcastMode, plan) val hashMode = HashedRelationBroadcastMode(output) - val exchange2 = BroadcastExchangeExec(hashMode, plan) + val exchange2 = BroadcastExchangeExec[HashedRelation](hashMode, plan) val hashMode2 = HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) - val exchange3 = BroadcastExchangeExec(hashMode2, plan) + val exchange3 = BroadcastExchangeExec[HashedRelation](hashMode2, plan) val exchange4 = ReusedExchangeExec(output, exchange3) assert(exchange1 sameResult exchange1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 4408ece11225..84666aa581e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -127,26 +127,30 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } - test(s"$testName using BroadcastHashJoin (build=left)") { + def usingBroadcastHashJoin(buildSide: 
joins.BuildSide): Unit = { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, buildSide), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } } } + test(s"$testName using BroadcastHashJoin (build=left)") { + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastHashJoin(joins.BuildLeft) + } + } + } + test(s"$testName using BroadcastHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastHashJoin(joins.BuildRight) } } } @@ -196,21 +200,28 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using BroadcastNestedLoopJoin build left") { + def usingBroadcastNestedLoopJoin(buildSide: joins.BuildSide): Unit = { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildLeft, Inner, Some(condition())), + BroadcastNestedLoopJoinExec(left, right, buildSide, Inner, Some(condition())), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } } + test(s"$testName using BroadcastNestedLoopJoin build left") { + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastNestedLoopJoin(BuildLeft) + } + } + } + test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildRight, Inner, Some(condition())), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastNestedLoopJoin(BuildRight) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 001feb0f2b39..cc3ffd74dded 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -92,20 +92,28 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } } + def usingBroadcastHashJoin(): Unit = { + val buildSide = joinType match { + case LeftOuter => BuildRight + case RightOuter => BuildLeft + case _ => fail(s"Unsupported join type $joinType") + } + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + 
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + if (joinType != FullOuter) { test(s"$testName using BroadcastHashJoin") { - val buildSide = joinType match { - case LeftOuter => BuildRight - case RightOuter => BuildLeft - case _ => fail(s"Unsupported join type $joinType") - } - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastHashJoin() } } } @@ -123,21 +131,28 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using BroadcastNestedLoopJoin build left") { + def usingBroadcastNestedLoopJoin(buildSide: BuildSide): Unit = { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)), + BroadcastNestedLoopJoinExec(left, right, buildSide, joinType, Some(condition)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } } + test(s"$testName using BroadcastNestedLoopJoin build left") { + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastNestedLoopJoin(BuildLeft) + } + } + } + test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastNestedLoopJoin(BuildRight) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 58a194b8af62..0c9bcb6631c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -225,10 +225,13 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Assume the execution plan is // ... 
-> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) val df = df1.join(broadcast(df2), "key") - testSparkPlanMetrics(df, 2, Map( - 1L -> (("BroadcastHashJoin", Map( - "number of output rows" -> 2L)))) - ) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + testSparkPlanMetrics(df, 2, Map( + 1L -> (("BroadcastHashJoin", Map( + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))) + ) + } } test("BroadcastHashJoin metrics: track avg probe") { @@ -256,20 +259,22 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Project(nodeId = 3) // Filter(nodeId = 4) // ...(ignored) - Seq(true, false).foreach { enableWholeStage => - val df1 = generateRandomBytesDF() - val df2 = generateRandomBytesDF() - val df = df1.join(broadcast(df2), "a") - val nodeIds = if (enableWholeStage) { - Set(2L) - } else { - Set(1L) - } - val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get - nodeIds.foreach { nodeId => - val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") - probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => - assert(probe.toDouble > 1.0) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF() + val df2 = generateRandomBytesDF() + val df = df1.join(broadcast(df2), "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } } } } @@ -346,32 +351,36 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Assume the execution plan is // ... -> BroadcastHashJoin(nodeId = 0) val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") - testSparkPlanMetrics(df, 2, Map( - 0L -> (("BroadcastHashJoin", Map( - "number of output rows" -> 5L)))) - ) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + testSparkPlanMetrics(df, 2, Map( + 0L -> (("BroadcastHashJoin", Map( + "number of output rows" -> 5L)))) + ) - val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") - testSparkPlanMetrics(df3, 2, Map( - 0L -> (("BroadcastHashJoin", Map( - "number of output rows" -> 6L)))) - ) + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> (("BroadcastHashJoin", Map( + "number of output rows" -> 6L)))) + ) + } } test("BroadcastNestedLoopJoin metrics") { val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") - withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - withTempView("testDataForJoin") { - // Assume the execution plan is - // ... 
-> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val df = spark.sql(
-        "SELECT * FROM testData2 left JOIN testDataForJoin ON " +
-          "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
-      testSparkPlanMetrics(df, 3, Map(
-        1L -> (("BroadcastNestedLoopJoin", Map(
-          "number of output rows" -> 12L))))
-      )
+    withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") {
+      withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+        withTempView("testDataForJoin") {
+          // Assume the execution plan is
+          // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
+          val df = spark.sql(
+            "SELECT * FROM testData2 left JOIN testDataForJoin ON " +
+              "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
+          testSparkPlanMetrics(df, 3, Map(
+            1L -> (("BroadcastNestedLoopJoin", Map(
+              "number of output rows" -> 12L))))
+          )
+        }
       }
     }
   }
@@ -382,10 +391,12 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     // Assume the execution plan is
     // ... -> BroadcastHashJoin(nodeId = 0)
     val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi")
-    testSparkPlanMetrics(df, 2, Map(
-      0L -> (("BroadcastHashJoin", Map(
-        "number of output rows" -> 2L))))
-    )
+    withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") {
+      testSparkPlanMetrics(df, 2, Map(
+        0L -> (("BroadcastHashJoin", Map(
+          "number of output rows" -> 2L))))
+      )
+    }
   }
 
   test("CartesianProduct metrics") {
@@ -474,21 +485,33 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     )
     assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)
 
-    withTempDir { tempDir =>
-      val dir = new File(tempDir, "pqS").getCanonicalPath
-
-      spark.range(10).write.parquet(dir)
-      spark.read.parquet(dir).createOrReplaceTempView("pqS")
-
-      val res3 = InputOutputMetricsHelper.run(
-        spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
-      )
-      // The query above is executed in the following stages:
-      //   1. sql("select * from pqS") => (10, 0, 10)
-      //   2. range(30) => (30, 0, 30)
-      //   3. crossJoin(...) of 1. and 2. => (0, 30, 300)
-      //   4. shuffle & return results => (0, 300, 0)
-      assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
+    Seq(true, false).foreach { executorBroadcast =>
+      withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorBroadcast.toString) {
+        withTempDir { tempDir =>
+          val dir = new File(tempDir, "pqS").getCanonicalPath
+
+          spark.range(10).write.parquet(dir)
+          spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+          val res3 = InputOutputMetricsHelper.run(
+            spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
+          )
+          // The query above is executed in the following stages:
+          //   1a. sql("select * from pqS") => (10, 0, 10)
+          //   1b. (only when `SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED` is enabled)
+          //       executor-side broadcast => (0, 0, 0)
+          //   2. range(30) => (30, 0, 30)
+          //   3. crossJoin(...) of 1. and 2. => (0, 30, 300)
+          //   4. shuffle & return results => (0, 300, 0)
+          val expected = if (executorBroadcast) {
+            (10L, 0L, 10L) :: (0L, 0L, 0L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) ::
+              (0L, 300L, 0L) :: Nil
+          } else {
+            (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil
+          }
+          assert(res3 === expected)
+        }
+      }
     }
   }
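Usage note (not part of the patch): the snippet below is a minimal sketch of how the developer API introduced above fits together, mirroring the RDDSuite test added in this diff. It assumes the patch is applied; the object name, the `ArrayBroadcastMode` class, and the local master setting are illustrative only and are not defined anywhere in Spark.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.broadcast.BroadcastMode

// Illustrative identity-style mode: the broadcasted object is simply the array of all RDD
// elements. A real BroadcastMode would typically build something like a hashed relation here.
class ArrayBroadcastMode extends BroadcastMode[Int] {
  override def transform(rows: Array[Int]): Array[Int] = rows
  override def transform(rows: Iterator[Int], sizeHint: Option[Long]): Array[Int] = rows.toArray
}

object ExecutorSideBroadcastExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("executor-side-broadcast-example").setMaster("local[2]"))
    // The RDD must be cached and materialized before it can be broadcasted from executors.
    val rdd = sc.parallelize(1 to 4, 2).cache()
    rdd.count()
    // Broadcast the RDD's blocks between executors, without collecting them to the driver.
    val broadcasted = sc.broadcastRDDOnExecutor[Int, Array[Int]](rdd, new ArrayBroadcastMode)
    // Read the broadcasted value inside tasks, exactly like a regular Broadcast variable.
    val sums = sc.parallelize(1 to 2, 2).map(_ => broadcasted.value.sum).collect()
    assert(sums.forall(_ == 10)) // 1 + 2 + 3 + 4 = 10 in every task
    sc.stop()
  }
}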