Commit 88681bd

viirya authored and huangxiaopingRD committed
[SPARK-54134][SQL] Optimize Arrow memory usage
### What changes were proposed in this pull request?

This patch proposes changes to optimize Arrow memory usage in Spark: it compresses Arrow IPC data when serializing, controlled by a new `spark.sql.execution.arrow.compressionCodec` config.

### Why are the changes needed?

We have encountered OOMs when loading data and processing it in PySpark through `toArrow` or `toPandas`. The same data can be loaded by PyArrow directly, but fails to load into PySpark through `toArrow` or `toPandas` due to OOM issues.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests. Manually tested it locally.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude Code v2.0.13

Closes apache#52747 from viirya/release_buffers.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 7247261 commit 88681bd
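
For context, a minimal PySpark usage sketch of the new config (the DataFrame contents here are illustrative; the codec applies to the Arrow IPC batches produced for `toPandas()` and `toArrow()`):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit

spark = SparkSession.builder.getOrCreate()

# Compress Arrow IPC data sent from the JVM to Python.
# "none" is the default; "zstd" and "lz4" are the other supported values.
spark.conf.set("spark.sql.execution.arrow.compressionCodec", "zstd")

df = spark.range(10000).select(
    col("id"),
    lit("test_string_value_" * 10).alias("str_col"),
)

pdf = df.toPandas()   # pandas DataFrame, transferred as compressed Arrow batches
table = df.toArrow()  # pyarrow.Table, same compressed transfer path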

File tree

6 files changed: +137 -2 lines changed

dev/deps/spark-deps-hadoop-3-hive-2.3

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ antlr4-runtime/4.13.1//antlr4-runtime-4.13.1.jar
 aopalliance-repackaged/3.0.6//aopalliance-repackaged-3.0.6.jar
 arpack/3.0.4//arpack-3.0.4.jar
 arpack_combined_all/0.1//arpack_combined_all-0.1.jar
+arrow-compression/18.3.0//arrow-compression-18.3.0.jar
 arrow-format/18.3.0//arrow-format-18.3.0.jar
 arrow-memory-core/18.3.0//arrow-memory-core-18.3.0.jar
 arrow-memory-netty-buffer-patch/18.3.0//arrow-memory-netty-buffer-patch-18.3.0.jar

pom.xml

Lines changed: 19 additions & 0 deletions
@@ -2519,6 +2519,25 @@
           </exclusion>
         </exclusions>
       </dependency>
+      <dependency>
+        <groupId>org.apache.arrow</groupId>
+        <artifactId>arrow-compression</artifactId>
+        <version>${arrow.version}</version>
+        <exclusions>
+          <exclusion>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-annotations</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-core</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>io.netty</groupId>
+            <artifactId>netty-common</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
       <dependency>
         <groupId>org.apache.arrow</groupId>
         <artifactId>arrow-memory-netty</artifactId>

python/pyspark/sql/tests/arrow/test_arrow.py

Lines changed: 75 additions & 0 deletions
@@ -1810,6 +1810,81 @@ def test_createDataFrame_arrow_fixed_size_list(self):
         df = self.spark.createDataFrame(t)
         self.assertIsInstance(df.schema["fsl"].dataType, ArrayType)
 
+    def test_toPandas_with_compression_codec(self):
+        # Test toPandas() with different compression codec settings
+        df = self.spark.createDataFrame(self.data, schema=self.schema)
+        expected = self.create_pandas_data_frame()
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    pdf = df.toPandas()
+                    assert_frame_equal(expected, pdf)
+
+    def test_toArrow_with_compression_codec(self):
+        # Test toArrow() with different compression codec settings
+        import pyarrow.compute as pc
+
+        t_in = self.create_arrow_table()
+
+        # Convert timezone-naive local timestamp column in input table to UTC
+        # to enable comparison to UTC timestamp column in output table
+        timezone = self.spark.conf.get("spark.sql.session.timeZone")
+        t_in = t_in.set_column(
+            t_in.schema.get_field_index("8_timestamp_t"),
+            "8_timestamp_t",
+            pc.assume_timezone(t_in["8_timestamp_t"], timezone),
+        )
+        t_in = t_in.cast(
+            t_in.schema.set(
+                t_in.schema.get_field_index("8_timestamp_t"),
+                pa.field("8_timestamp_t", pa.timestamp("us", tz="UTC")),
+            )
+        )
+
+        df = self.spark.createDataFrame(self.data, schema=self.schema)
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    t_out = df.toArrow()
+                    self.assertTrue(t_out.equals(t_in))
+
+    def test_toPandas_with_compression_codec_large_dataset(self):
+        # Test compression with a larger dataset to verify memory savings
+        # Create a dataset with repetitive data that compresses well
+        from pyspark.sql.functions import lit, col
+
+        df = self.spark.range(10000).select(
+            col("id"),
+            lit("test_string_value_" * 10).alias("str_col"),
+            (col("id") % 100).alias("mod_col"),
+        )
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    pdf = df.toPandas()
+                    self.assertEqual(len(pdf), 10000)
+                    self.assertEqual(pdf.columns.tolist(), ["id", "str_col", "mod_col"])
+
+    def test_toArrow_with_compression_codec_large_dataset(self):
+        # Test compression with a larger dataset for toArrow
+        from pyspark.sql.functions import lit, col
+
+        df = self.spark.range(10000).select(
+            col("id"),
+            lit("test_string_value_" * 10).alias("str_col"),
+            (col("id") % 100).alias("mod_col"),
+        )
+
+        for codec in ["none", "zstd", "lz4"]:
+            with self.subTest(compressionCodec=codec):
+                with self.sql_conf({"spark.sql.execution.arrow.compressionCodec": codec}):
+                    t = df.toArrow()
+                    self.assertEqual(t.num_rows, 10000)
+                    self.assertEqual(t.column_names, ["id", "str_col", "mod_col"])
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 16 additions & 0 deletions
@@ -3987,6 +3987,20 @@
         "than zero and less than INT_MAX.")
       .createWithDefaultString("64MB")
 
+  val ARROW_EXECUTION_COMPRESSION_CODEC =
+    buildConf("spark.sql.execution.arrow.compressionCodec")
+      .doc("Compression codec used to compress Arrow IPC data when transferring data " +
+        "between JVM and Python processes (e.g., toPandas, toArrow). This can significantly " +
+        "reduce memory usage and network bandwidth when transferring large datasets. " +
+        "Supported codecs: 'none' (no compression), 'zstd' (Zstandard), 'lz4' (LZ4). " +
+        "Note that compression may add CPU overhead but can provide substantial memory savings " +
+        "especially for datasets with high compression ratios.")
+      .version("4.1.0")
+      .stringConf
+      .transform(_.toLowerCase(java.util.Locale.ROOT))
+      .checkValues(Set("none", "zstd", "lz4"))
+      .createWithDefault("none")
+
   val ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH =
     buildConf("spark.sql.execution.arrow.transformWithStateInPySpark.maxStateRecordsPerBatch")
       .doc("When using TransformWithState in PySpark (both Python Row and Pandas), limit " +
@@ -7332,6 +7346,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def arrowMaxBytesPerBatch: Long = getConf(ARROW_EXECUTION_MAX_BYTES_PER_BATCH)
 
+  def arrowCompressionCodec: String = getConf(ARROW_EXECUTION_COMPRESSION_CODEC)
+
   def arrowTransformWithStateInPySparkMaxStateRecordsPerBatch: Int =
     getConf(ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH)
 

sql/core/pom.xml

Lines changed: 4 additions & 0 deletions
@@ -279,6 +279,10 @@
       <artifactId>bcpkix-jdk18on</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-compression</artifactId>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala

Lines changed: 22 additions & 2 deletions
@@ -23,9 +23,11 @@ import java.nio.channels.{Channels, ReadableByteChannel}
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
 import org.apache.arrow.flatbuf.MessageHeader
 import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector._
+import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
 import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel}
 import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}
 
@@ -37,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -92,8 +95,25 @@ private[sql] object ArrowConverters extends Logging {
       ArrowUtils.rootAllocator.newChildAllocator(
         s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
 
-    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    protected val unloader = new VectorUnloader(root)
+    protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+    // Create compression codec based on config
+    private val compressionCodecName = SQLConf.get.arrowCompressionCodec
+    private val codec = compressionCodecName match {
+      case "none" => NoCompressionCodec.INSTANCE
+      case "zstd" =>
+        val factory = CompressionCodec.Factory.INSTANCE
+        val codecType = new ZstdCompressionCodec().getCodecType()
+        factory.createCodec(codecType)
+      case "lz4" =>
+        val factory = CompressionCodec.Factory.INSTANCE
+        val codecType = new Lz4CompressionCodec().getCodecType()
+        factory.createCodec(codecType)
+      case other =>
+        throw new IllegalArgumentException(
+          s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+    }
+    protected val unloader = new VectorUnloader(root, true, codec, true)
     protected val arrowWriter = ArrowWriter.create(root)
 
     Option(context).foreach {_.addTaskCompletionListener[Unit] { _ =>
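
For reference, the codec selected above is applied through Arrow's standard IPC body compression, so readers such as PyArrow decompress the record batches transparently on the Python side. A standalone PyArrow sketch of that same mechanism (not part of this patch; assumes a PyArrow build with zstd support):

import pyarrow as pa

table = pa.table({"id": list(range(10000)), "s": ["test_string_value_" * 10] * 10000})

# Write an IPC stream whose record batch buffers are zstd-compressed.
sink = pa.BufferOutputStream()
options = pa.ipc.IpcWriteOptions(compression="zstd")
with pa.ipc.new_stream(sink, table.schema, options=options) as writer:
    writer.write_table(table)

# The stream reader detects the compression from the record batch metadata
# and decompresses the buffers transparently.
with pa.ipc.open_stream(sink.getvalue()) as reader:
    roundtrip = reader.read_all()
assert roundtrip.equals(table)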
