diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index eee6e4b28ac4..62d60475985b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -81,4 +81,8 @@ private[spark] object PythonUtils {
   def isEncryptionEnabled(sc: JavaSparkContext): Boolean = {
     sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED)
   }
+
+  def getBroadcastThreshold(sc: JavaSparkContext): Long = {
+    sc.conf.get(org.apache.spark.internal.config.BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD)
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 488886f1627f..76d3d6ee3d8f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1246,6 +1246,14 @@ package object config {
       "mechanisms to guarantee data won't be corrupted during broadcast")
     .booleanConf.createWithDefault(true)
 
+  private[spark] val BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD =
+    ConfigBuilder("spark.broadcast.UDFCompressionThreshold")
+      .doc("The threshold at which user-defined functions (UDFs) and Python RDD commands " +
+        "are compressed by broadcast in bytes unless otherwise specified")
+      .bytesConf(ByteUnit.BYTE)
+      .checkValue(v => v >= 0, "The threshold should be non-negative.")
+      .createWithDefault(1L * 1024 * 1024)
+
   private[spark] val RDD_COMPRESS = ConfigBuilder("spark.rdd.compress")
     .doc("Whether to compress serialized RDD partitions " +
       "(e.g. for StorageLevel.MEMORY_ONLY_SER in Scala " +
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 6be1fedc123d..202b85dcf569 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -389,6 +389,19 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
         """.stripMargin.trim)
   }
 
+  test("SPARK-28355: Use Spark conf for threshold at which UDFs are compressed by broadcast") {
+    val conf = new SparkConf()
+
+    // Check the default value
+    assert(conf.get(BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD) === 1L * 1024 * 1024)
+
+    // Set the conf
+    conf.set(BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD, 1L * 1024)
+
+    // Verify that it has been set properly
+    assert(conf.get(BROADCAST_FOR_UDF_COMPRESSION_THRESHOLD) === 1L * 1024)
+  }
+
   val defaultIllegalValue = "SomeIllegalValue"
   val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map(
     "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)),
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8bcc67ab1c3e..96fdf5f33b39 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2490,7 +2490,7 @@ def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps(command)
-    if len(pickled_command) > (1 << 20):  # 1M
+    if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc):  # Default 1M
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
         pickled_command = ser.dumps(broadcast)