From 8522cf7d007ac342cda4982f08c84ea6379e52df Mon Sep 17 00:00:00 2001
From: Eric Marnadi
Date: Fri, 1 Aug 2025 09:25:36 -0700
Subject: [PATCH] [SPARK-53001][FOLLOW-UP] Integrate RocksDB Memory Usage with
 the Unified Memory Manager

---
 .../spark/internal/config/package.scala       |  2 +-
 .../spark/memory/UnifiedMemoryManager.scala   | 56 +--------------
 .../memory/UnmanagedMemoryConsumer.scala      | 72 +++++++++++++++++++
 .../execution/streaming/state/RocksDB.scala   |  5 +-
 .../state/RocksDBMemoryManager.scala          |  2 +-
 .../RocksDBStateStoreIntegrationSuite.scala   | 49 +++++--------
 6 files changed, 95 insertions(+), 91 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/memory/UnmanagedMemoryConsumer.scala

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 0be2a53d7a0f..c25a4fd45c58 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -508,7 +508,7 @@ package object config {
         "Setting this to 0 disables unmanaged memory polling.")
       .version("4.1.0")
       .timeConf(TimeUnit.MILLISECONDS)
-      .createWithDefaultString("1s")
+      .createWithDefaultString("0s")
 
   private[spark] val STORAGE_UNROLL_MEMORY_THRESHOLD =
     ConfigBuilder("spark.storage.unrollMemoryThreshold")
diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
index 0aec2c232aab..212b54239ee6 100644
--- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
@@ -283,7 +283,7 @@ object UnifiedMemoryManager extends Logging {
    * @param unmanagedMemoryConsumer The consumer to register for memory tracking
    */
   def registerUnmanagedMemoryConsumer(
-    unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = {
+      unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = {
     val id = unmanagedMemoryConsumer.unmanagedMemoryConsumerId
     unmanagedMemoryConsumers.put(id, unmanagedMemoryConsumer)
   }
@@ -481,57 +481,3 @@ object UnifiedMemoryManager extends Logging {
     (usableMemory * memoryFraction).toLong
   }
 }
-
-/**
- * Identifier for an unmanaged memory consumer.
- *
- * @param componentType The type of component (e.g., "RocksDB", "NativeLibrary")
- * @param instanceKey A unique key to identify this specific instance of the component.
- *                    For shared memory consumers, this should be a common key across
- *                    all instances to avoid double counting.
- */
-case class UnmanagedMemoryConsumerId(
-    componentType: String,
-    instanceKey: String
-  )
-
-/**
- * Interface for components that consume memory outside of Spark's unified memory management.
- *
- * Components implementing this trait can register themselves with the memory manager
- * to have their memory usage tracked and factored into memory allocation decisions.
- * This helps prevent OOM errors when unmanaged components use significant memory.
- *
- * Examples of unmanaged memory consumers:
- * - RocksDB state stores in structured streaming
- * - Native libraries with custom memory allocation
- * - Off-heap caches managed outside of Spark
- */
-trait UnmanagedMemoryConsumer {
-  /**
-   * Returns the unique identifier for this memory consumer.
-   * The identifier is used to track and manage the consumer in the memory tracking system.
-   */
-  def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId
-
-  /**
-   * Returns the memory mode (ON_HEAP or OFF_HEAP) that this consumer uses.
-   * This is used to ensure unmanaged memory usage only affects the correct memory pool.
-   */
-  def memoryMode: MemoryMode
-
-  /**
-   * Returns the current memory usage in bytes.
-   *
-   * This method is called periodically by the memory polling mechanism to track
-   * memory usage over time. Implementations should return the current total memory
-   * consumed by this component.
-   *
-   * @return Current memory usage in bytes. Should return 0 if no memory is currently used.
-   *         Return -1L to indicate this consumer is no longer active and should be
-   *         automatically removed from tracking.
-   * @throws Exception if memory usage cannot be determined. The polling mechanism
-   *                   will handle exceptions gracefully and log warnings.
-   */
-  def getMemBytesUsed: Long
-}
diff --git a/core/src/main/scala/org/apache/spark/memory/UnmanagedMemoryConsumer.scala b/core/src/main/scala/org/apache/spark/memory/UnmanagedMemoryConsumer.scala
new file mode 100644
index 000000000000..835191828215
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/memory/UnmanagedMemoryConsumer.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+/**
+ * Identifier for an unmanaged memory consumer.
+ *
+ * @param componentType The type of component (e.g., "RocksDB", "NativeLibrary")
+ * @param instanceKey A unique key to identify this specific instance of the component.
+ *                    For shared memory consumers, this should be a common key across
+ *                    all instances to avoid double counting.
+ */
+case class UnmanagedMemoryConsumerId(
+    componentType: String,
+    instanceKey: String
+)
+
+/**
+ * Interface for components that consume memory outside of Spark's unified memory management.
+ *
+ * Components implementing this trait can register themselves with the memory manager
+ * to have their memory usage tracked and factored into memory allocation decisions.
+ * This helps prevent OOM errors when unmanaged components use significant memory.
+ *
+ * Examples of unmanaged memory consumers:
+ * - RocksDB state stores in structured streaming
+ * - Native libraries with custom memory allocation
+ * - Off-heap caches managed outside of Spark
+ */
+trait UnmanagedMemoryConsumer {
+  /**
+   * Returns the unique identifier for this memory consumer.
+   * The identifier is used to track and manage the consumer in the memory tracking system.
+   */
+  def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId
+
+  /**
+   * Returns the memory mode (ON_HEAP or OFF_HEAP) that this consumer uses.
+   * This is used to ensure unmanaged memory usage only affects the correct memory pool.
+   */
+  def memoryMode: MemoryMode
+
+  /**
+   * Returns the current memory usage in bytes.
+   *
+   * This method is called periodically by the memory polling mechanism to track
+   * memory usage over time. Implementations should return the current total memory
+   * consumed by this component.
+   *
+   * @return Current memory usage in bytes. Should return 0 if no memory is currently used.
+   *         Return -1L to indicate this consumer is no longer active and should be
+   *         automatically removed from tracking.
+   * @throws Exception if memory usage cannot be determined. The polling mechanism
+   *                   will handle exceptions gracefully and log warnings.
+   */
+  def getMemBytesUsed: Long
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index dd8a99500a1a..641093d7da47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -1396,11 +1396,8 @@ class RocksDB(
    * @return Total memory usage in bytes across all tracked components
    */
   def getMemoryUsage: Long = {
-    require(db != null && !db.isClosed, "RocksDB must be open to get memory usage")
-
-    RocksDB.mainMemorySources.map { memorySource =>
-      getDBProperty(memorySource)
-    }.sum
+    RocksDB.mainMemorySources.map(getDBProperty).sum
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala
index 80ad864600b2..2fc5c37814a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala
@@ -36,7 +36,7 @@ import org.apache.spark.memory.{MemoryMode, UnifiedMemoryManager, UnmanagedMemor
  * UnifiedMemoryManager, allowing Spark to account for RocksDB memory when making
  * memory allocation decisions.
  */
-object RocksDBMemoryManager extends Logging with UnmanagedMemoryConsumer{
+object RocksDBMemoryManager extends Logging with UnmanagedMemoryConsumer {
   private var writeBufferManager: WriteBufferManager = null
   private var cache: Cache = null
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index d2c95dfe5016..9f5256797a7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -21,8 +21,9 @@ import java.io.File
 
 import scala.jdk.CollectionConverters.SetHasAsScala
 
-import org.scalatest.time.{Minute, Span}
+import org.scalatest.time.{Millis, Minute, Seconds, Span}
 
+import org.apache.spark.memory.UnifiedMemoryManager
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
 import org.apache.spark.sql.functions.{count, max}
 import org.apache.spark.sql.internal.SQLConf
@@ -329,9 +330,6 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
         ("spark.sql.streaming.stateStore.rocksdb.boundedMemoryUsage" ->
           boundedMemoryEnabled.toString)) {
 
-        import org.apache.spark.memory.UnifiedMemoryManager
-        import org.apache.spark.sql.streaming.Trigger
-
         // Use rate stream to ensure continuous state operations that trigger memory updates
         val query = spark.readStream
           .format("rate")
@@ -347,38 +345,29 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
           .start()
 
         try {
-          // Let the stream run to establish RocksDB instances and generate state operations
-          Thread.sleep(2000) // 2 seconds should be enough for several processing cycles
-
-          // Now check for memory tracking - the continuous stream should trigger memory updates
-          var rocksDBMemory = 0L
-          var attempts = 0
-          val maxAttempts = 15 // 15 attempts with 1-second intervals = 15 seconds max
-
-          while (rocksDBMemory <= 0L && attempts < maxAttempts) {
-            Thread.sleep(1000) // Wait between checks to allow memory updates
-            rocksDBMemory = UnifiedMemoryManager.getMemoryByComponentType("RocksDB")
-            attempts += 1
-
-            if (rocksDBMemory > 0L) {
-              logInfo(s"RocksDB memory detected: $rocksDBMemory bytes " +
-                s"after $attempts attempts with boundedMemory=$boundedMemoryEnabled")
-            }
+          // Check for memory tracking - the continuous stream should trigger memory updates
+          var initialRocksDBMemory = 0L
+          eventually(timeout(Span(20, Seconds)), interval(Span(500, Millis))) {
+            initialRocksDBMemory = UnifiedMemoryManager.getMemoryByComponentType("RocksDB")
+            assert(initialRocksDBMemory > 0L,
+              s"RocksDB memory should be tracked with boundedMemory=$boundedMemoryEnabled")
           }
 
+          logInfo(s"RocksDB memory detected: $initialRocksDBMemory bytes " +
+            s"with boundedMemory=$boundedMemoryEnabled")
+
           // Verify memory tracking remains stable during continued operation
-          Thread.sleep(2000) // Let stream continue running
+          eventually(timeout(Span(5, Seconds)), interval(Span(500, Millis))) {
+            val currentMemory = UnifiedMemoryManager.getMemoryByComponentType("RocksDB")
+            assert(currentMemory > 0L,
+              s"RocksDB memory tracking should remain active during stream processing: " +
+              s"got $currentMemory bytes (initial: $initialRocksDBMemory) " +
+              s"with boundedMemory=$boundedMemoryEnabled")
+          }
 
           val finalMemory = UnifiedMemoryManager.getMemoryByComponentType("RocksDB")
-
-          // Memory should still be tracked (allow for some fluctuation)
-          assert(finalMemory > 0L,
-            s"RocksDB memory tracking should remain active during stream processing: " +
-            s"got $finalMemory bytes (initial: $rocksDBMemory) " +
-            s"with boundedMemory=$boundedMemoryEnabled")
-
           logInfo(s"RocksDB memory tracking test completed successfully: " +
-            s"initial=$rocksDBMemory bytes, final=$finalMemory bytes " +
+            s"initial=$initialRocksDBMemory bytes, final=$finalMemory bytes " +
             s"with boundedMemory=$boundedMemoryEnabled")
         } finally {
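
Usage note (not part of the patch): a minimal sketch of how a component other than
RocksDBMemoryManager might adopt the UnmanagedMemoryConsumer API added above. The
trait, UnmanagedMemoryConsumerId, MemoryMode, and
UnifiedMemoryManager.registerUnmanagedMemoryConsumer are taken from this patch;
NativeCacheTracker and estimateNativeBytes are hypothetical names used only for
illustration. Also note the patch changes the polling-interval default to "0s"
(polling disabled), so that interval config must be set to a positive value for
registered consumers to actually be sampled.

import org.apache.spark.memory.{MemoryMode, UnifiedMemoryManager,
  UnmanagedMemoryConsumer, UnmanagedMemoryConsumerId}

// Hypothetical tracker for a native, off-heap cache shared by several instances.
class NativeCacheTracker(sharedCacheKey: String) extends UnmanagedMemoryConsumer {

  @volatile private var closed = false

  // Instances sharing one cache report under a common instanceKey so the
  // polling mechanism does not double count the same allocation.
  override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId =
    UnmanagedMemoryConsumerId("NativeCache", sharedCacheKey)

  // Usage is charged against the off-heap pool only.
  override def memoryMode: MemoryMode = MemoryMode.OFF_HEAP

  // Polled periodically; returning -1L asks the poller to drop this consumer.
  override def getMemBytesUsed: Long =
    if (closed) -1L else estimateNativeBytes()

  // Placeholder for whatever accounting the native library exposes.
  private def estimateNativeBytes(): Long = 0L

  def close(): Unit = closed = true
}

object NativeCacheTracker {
  // Register once at initialization, mirroring how RocksDBMemoryManager registers
  // itself in this patch; registering again under the same id replaces the entry.
  def register(key: String): Unit =
    UnifiedMemoryManager.registerUnmanagedMemoryConsumer(new NativeCacheTracker(key))
}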