diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 2c7c2f120b93..0be2a53d7a0f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -500,6 +500,16 @@ package object config { .doubleConf .createWithDefault(0.6) + private[spark] val UNMANAGED_MEMORY_POLLING_INTERVAL = + ConfigBuilder("spark.memory.unmanagedMemoryPollingInterval") + .doc("Interval for polling unmanaged memory users to track their memory usage. " + + "Unmanaged memory users are components that manage their own memory outside of " + + "Spark's core memory management, such as RocksDB for Streaming State Store. " + + "Setting this to 0 disables unmanaged memory polling.") + .version("4.1.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("1s") + private[spark] val STORAGE_UNROLL_MEMORY_THRESHOLD = ConfigBuilder("spark.storage.unrollMemoryThreshold") .doc("Initial memory to request before unrolling any block") diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index d4ec6ed8495a..0aec2c232aab 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -17,11 +17,19 @@ package org.apache.spark.memory +import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +import scala.jdk.CollectionConverters._ +import scala.util.control.NonFatal + import org.apache.spark.{SparkConf, SparkIllegalArgumentException} -import org.apache.spark.internal.{config, MDC} +import org.apache.spark.internal.{config, Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys._ import 
org.apache.spark.internal.config.Tests._ +import org.apache.spark.internal.config.UNMANAGED_MEMORY_POLLING_INTERVAL import org.apache.spark.storage.BlockId +import org.apache.spark.util.{ThreadUtils, Utils} /** * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that @@ -56,7 +64,47 @@ private[spark] class UnifiedMemoryManager( conf, numCores, onHeapStorageRegionSize, - maxHeapMemory - onHeapStorageRegionSize) { + maxHeapMemory - onHeapStorageRegionSize) with Logging { + + /** + * Unmanaged memory tracking infrastructure. + * + * Unmanaged memory refers to memory consumed by components that manage their own memory + * outside of Spark's unified memory management system. Examples include: + * - RocksDB state stores used in structured streaming + * - Native libraries with their own memory management + * - Off-heap caches managed outside of Spark + * + * We track this memory to: + * 1. Provide visibility into total memory usage on executors + * 2. Prevent OOM errors by accounting for it in memory allocation decisions + * 3. Enable better debugging and monitoring of memory-intensive applications + * + * The polling mechanism periodically queries registered unmanaged memory consumers + * to refresh the cached usage totals and to remove consumers that report as inactive. + */ + // Configuration for polling interval (in milliseconds) + private val unmanagedMemoryPollingIntervalMs = conf.get(UNMANAGED_MEMORY_POLLING_INTERVAL) + // Initialize background polling if enabled + if (unmanagedMemoryPollingIntervalMs > 0) { + UnifiedMemoryManager.startPollingIfNeeded(unmanagedMemoryPollingIntervalMs) + } + + /** + * Get the current unmanaged memory usage in bytes for a specific memory mode.
+ * @param memoryMode The memory mode (ON_HEAP or OFF_HEAP) to get usage for + * @return The current unmanaged memory usage in bytes + */ + private def getUnmanagedMemoryUsed(memoryMode: MemoryMode): Long = { + // Only consider unmanaged memory if polling is enabled + if (unmanagedMemoryPollingIntervalMs <= 0) { + return 0L + } + memoryMode match { + case MemoryMode.ON_HEAP => UnifiedMemoryManager.unmanagedOnHeapUsed.get() + case MemoryMode.OFF_HEAP => UnifiedMemoryManager.unmanagedOffHeapUsed.get() + } + } private def assertInvariants(): Unit = { assert(onHeapExecutionMemoryPool.poolSize + onHeapStorageMemoryPool.poolSize == maxHeapMemory) @@ -140,9 +188,15 @@ private[spark] class UnifiedMemoryManager( * in execution memory allocation across tasks, Otherwise, a task may occupy more than * its fair share of execution memory, mistakenly thinking that other tasks can acquire * the portion of storage memory that cannot be evicted. + * + * This also factors in unmanaged memory usage to ensure we don't over-allocate memory + * when unmanaged components are consuming significant memory. 
*/ def computeMaxExecutionPoolSize(): Long = { - maxMemory - math.min(storagePool.memoryUsed, storageRegionSize) + val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode) + val availableMemory = maxMemory - math.min(storagePool.memoryUsed, storageRegionSize) + // Reduce available memory by unmanaged memory usage to prevent over-allocation + math.max(0L, availableMemory - unmanagedMemory) } executionPool.acquireMemory( @@ -165,11 +219,21 @@ private[spark] class UnifiedMemoryManager( offHeapStorageMemoryPool, maxOffHeapStorageMemory) } - if (numBytes > maxMemory) { + + // Factor in unmanaged memory usage for the specific memory mode + val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode) + val effectiveMaxMemory = math.max(0L, maxMemory - unmanagedMemory) + + if (numBytes > effectiveMaxMemory) { // Fail fast if the block simply won't fit logInfo(log"Will not store ${MDC(BLOCK_ID, blockId)} as the required space" + log" (${MDC(NUM_BYTES, numBytes)} bytes) exceeds our" + - log" memory limit (${MDC(NUM_BYTES_MAX, maxMemory)} bytes)") + log" memory limit (${MDC(NUM_BYTES_MAX, effectiveMaxMemory)} bytes)" + + (if (unmanagedMemory > 0) { + log" (unmanaged memory usage: ${MDC(NUM_BYTES, unmanagedMemory)} bytes)" + } else { + log"" + })) return false } if (numBytes > storagePool.memoryFree) { @@ -191,7 +255,7 @@ private[spark] class UnifiedMemoryManager( } } -object UnifiedMemoryManager { +object UnifiedMemoryManager extends Logging { // Set aside a fixed amount of memory for non-storage, non-execution purposes. // This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve @@ -199,6 +263,181 @@ object UnifiedMemoryManager { // the memory used for execution and storage will be (1024 - 300) * 0.6 = 434MB by default. 
private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024 + private val unmanagedMemoryConsumers = + new ConcurrentHashMap[UnmanagedMemoryConsumerId, UnmanagedMemoryConsumer] + + // Cached unmanaged memory usage values updated by polling + private val unmanagedOnHeapUsed = new AtomicLong(0L) + private val unmanagedOffHeapUsed = new AtomicLong(0L) + + // Atomic flag to ensure polling is only started once per JVM + private val pollingStarted = new AtomicBoolean(false) + + /** + * Register an unmanaged memory consumer to track its memory usage. + * + * Unmanaged memory consumers are components that manage their own memory outside + * of Spark's unified memory management system. By registering, their memory usage + * will be periodically polled and factored into Spark's memory allocation decisions. + * + * @param unmanagedMemoryConsumer The consumer to register for memory tracking + */ + def registerUnmanagedMemoryConsumer( + unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = { + val id = unmanagedMemoryConsumer.unmanagedMemoryConsumerId + unmanagedMemoryConsumers.put(id, unmanagedMemoryConsumer) + } + + /** + * Unregister an unmanaged memory consumer. + * This should be called when a component is shutting down to prevent memory leaks + * and ensure accurate memory tracking. + * + * @param unmanagedMemoryConsumer The consumer to unregister. Only used in tests + */ + private[spark] def unregisterUnmanagedMemoryConsumer( + unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = { + val id = unmanagedMemoryConsumer.unmanagedMemoryConsumerId + unmanagedMemoryConsumers.remove(id) + } + + + /** + * Get the current memory usage in bytes for a specific component type. 
+ * @param componentType The type of component to filter by (e.g., "RocksDB") + * @return Total bytes used by consumers of that type; inactive or failing consumers count as 0 + */ + def getMemoryByComponentType(componentType: String): Long = { + unmanagedMemoryConsumers.asScala.values.toSeq + .filter(_.unmanagedMemoryConsumerId.componentType == componentType) + .map { memoryUser => + try { + math.max(0L, memoryUser.getMemBytesUsed) // -1 (inactive) must not reduce the sum + } catch { + case NonFatal(_) => + 0L + } + } + .sum + } + + /** + * Clear all unmanaged memory users. + * This is useful during executor shutdown or cleanup. + * Since each executor runs in its own JVM, this clears all users for this executor. + */ + def clearUnmanagedMemoryUsers(): Unit = { + unmanagedMemoryConsumers.clear() + // Reset cached values when clearing consumers + unmanagedOnHeapUsed.set(0L) + unmanagedOffHeapUsed.set(0L) + } + + // Shared polling infrastructure - only one polling thread per JVM + @volatile private var unmanagedMemoryPoller: ScheduledExecutorService = _ + + /** + * Start unmanaged memory polling if not already started. + * This ensures only one polling thread is created per JVM, regardless of how many + * UnifiedMemoryManager instances are created.
+ */ + private[memory] def startPollingIfNeeded(pollingIntervalMs: Long): Unit = synchronized { + if (pollingStarted.compareAndSet(false, true)) { // same monitor as the poller shutdown + unmanagedMemoryPoller = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "unmanaged-memory-poller") + + val pollingTask = new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + pollUnmanagedMemoryUsers() + } + } + + unmanagedMemoryPoller.scheduleAtFixedRate( + pollingTask, + 0L, // initial delay + pollingIntervalMs, + TimeUnit.MILLISECONDS) + + logInfo(log"Unmanaged memory polling started with interval " + + log"${MDC(LogKeys.TIME, pollingIntervalMs)}ms") + } + } + + private def pollUnmanagedMemoryUsers(): Unit = { + val consumers = unmanagedMemoryConsumers.asScala.toMap + + // Get memory usage for each consumer, handling failures gracefully + val memoryUsages = consumers.map { case (userId, memoryUser) => + try { + val memoryUsed = memoryUser.getMemBytesUsed + if (memoryUsed == -1L) { + logDebug(log"Unmanaged memory consumer ${MDC(LogKeys.OBJECT_ID, userId.toString)} " + + log"is no longer active, marking for removal") + (userId, memoryUser, None) // Mark for removal + } else if (memoryUsed < 0L) { + logWarning(log"Invalid memory usage value ${MDC(LogKeys.NUM_BYTES, memoryUsed)} " + + log"from unmanaged memory user ${MDC(LogKeys.OBJECT_ID, userId.toString)}") + (userId, memoryUser, Some(0L)) // Treat as 0 + } else { + (userId, memoryUser, Some(memoryUsed)) + } + } catch { + case NonFatal(e) => + logWarning(log"Failed to get memory usage for unmanaged memory user " + + log"${MDC(LogKeys.OBJECT_ID, userId.toString)} ${MDC(LogKeys.EXCEPTION, e)}") + (userId, memoryUser, Some(0L)) // Treat as 0 on error + } + } + + // Remove inactive consumers + memoryUsages.filter(_._3.isEmpty).foreach { case (userId, _, _) => + unmanagedMemoryConsumers.remove(userId) + logInfo(log"Removed inactive unmanaged memory consumer " + + log"${MDC(LogKeys.OBJECT_ID, userId.toString)}") + } + // Calculate total memory usage by mode +
val activeUsages = memoryUsages.filter(_._3.isDefined) + val onHeapTotal = activeUsages + .filter(_._2.memoryMode == MemoryMode.ON_HEAP) + .map(_._3.get) + .sum + val offHeapTotal = activeUsages + .filter(_._2.memoryMode == MemoryMode.OFF_HEAP) + .map(_._3.get) + .sum + // Update cached values atomically + unmanagedOnHeapUsed.set(onHeapTotal) + unmanagedOffHeapUsed.set(offHeapTotal) + // Log polling results for monitoring + val totalMemoryUsed = onHeapTotal + offHeapTotal + val numConsumers = activeUsages.size + logDebug(s"Unmanaged memory polling completed: $numConsumers consumers, " + + s"total memory used: ${totalMemoryUsed} bytes " + + s"(on-heap: ${onHeapTotal}, off-heap: ${offHeapTotal})") + } + + /** + * Shutdown the unmanaged memory polling thread. Only used in tests + */ + private[spark] def shutdownUnmanagedMemoryPoller(): Unit = { + synchronized { + if (unmanagedMemoryPoller != null) { + unmanagedMemoryPoller.shutdown() + try { + if (!unmanagedMemoryPoller.awaitTermination(5, TimeUnit.SECONDS)) { + unmanagedMemoryPoller.shutdownNow() + } + } catch { + case _: InterruptedException => + Thread.currentThread().interrupt() + } + unmanagedMemoryPoller = null + pollingStarted.set(false) + logInfo(log"Unmanaged memory poller shutdown complete") + } + } + } + def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { val maxMemory = getMaxMemory(conf) new UnifiedMemoryManager( @@ -242,3 +481,57 @@ object UnifiedMemoryManager { (usableMemory * memoryFraction).toLong } } + +/** + * Identifier for an unmanaged memory consumer. + * + * @param componentType The type of component (e.g., "RocksDB", "NativeLibrary") + * @param instanceKey A unique key to identify this specific instance of the component. + * For shared memory consumers, this should be a common key across + * all instances to avoid double counting. 
+ */ +case class UnmanagedMemoryConsumerId( + componentType: String, + instanceKey: String + ) + +/** + * Interface for components that consume memory outside of Spark's unified memory management. + * + * Components implementing this trait can register themselves with the memory manager + * to have their memory usage tracked and factored into memory allocation decisions. + * This helps prevent OOM errors when unmanaged components use significant memory. + * + * Examples of unmanaged memory consumers: + * - RocksDB state stores in structured streaming + * - Native libraries with custom memory allocation + * - Off-heap caches managed outside of Spark + */ +trait UnmanagedMemoryConsumer { + /** + * Returns the unique identifier for this memory consumer. + * The identifier is used to track and manage the consumer in the memory tracking system. + */ + def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId + + /** + * Returns the memory mode (ON_HEAP or OFF_HEAP) that this consumer uses. + * This is used to ensure unmanaged memory usage only affects the correct memory pool. + */ + def memoryMode: MemoryMode + + /** + * Returns the current memory usage in bytes. + * + * This method is called periodically by the memory polling mechanism to track + * memory usage over time. Implementations should return the current total memory + * consumed by this component. + * + * @return Current memory usage in bytes. Should return 0 if no memory is currently used. + * Return -1L to indicate this consumer is no longer active and should be + * automatically removed from tracking. + * @throws Exception if memory usage cannot be determined. The polling mechanism + * will handle exceptions gracefully and log warnings. 
+ */ + def getMemBytesUsed: Long +} diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 0cafe6891c7d..9c74f2fdd459 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -340,5 +340,289 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) assertEvictBlocksToFreeSpaceCalled(ms, 50) assert(mm.storageMemoryUsed === 600L) + UnifiedMemoryManager.shutdownUnmanagedMemoryPoller() + } + + test("unmanaged memory tracking with memory mode separation") { + val maxMemory = 1000L + val taskAttemptId = 0L + val conf = new SparkConf() + .set(MEMORY_FRACTION, 1.0) + .set(TEST_MEMORY, maxMemory) + .set(MEMORY_OFFHEAP_ENABLED, false) + .set(MEMORY_STORAGE_FRACTION, storageFraction) + .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 100L) // 100ms polling + val mm = UnifiedMemoryManager(conf, numCores = 1) + val memoryMode = MemoryMode.ON_HEAP + + // Mock unmanaged memory consumer for ON_HEAP + class MockOnHeapMemoryConsumer(var memoryUsed: Long) extends UnmanagedMemoryConsumer { + override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("TestOnHeap", "test-instance") + override def memoryMode: MemoryMode = MemoryMode.ON_HEAP + override def getMemBytesUsed: Long = memoryUsed + } + + // Mock unmanaged memory consumer for OFF_HEAP + class MockOffHeapMemoryConsumer(var memoryUsed: Long) extends UnmanagedMemoryConsumer { + override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("TestOffHeap", "test-instance") + override def memoryMode: MemoryMode = MemoryMode.OFF_HEAP + override def getMemBytesUsed: Long = memoryUsed + } + + val onHeapConsumer = new MockOnHeapMemoryConsumer(0L) + val 
offHeapConsumer = new MockOffHeapMemoryConsumer(0L) + + try { + // Register both consumers + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(onHeapConsumer) + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(offHeapConsumer) + + // Initially no unmanaged memory usage + assert(UnifiedMemoryManager.getMemoryByComponentType("TestOnHeap") === 0L) + assert(UnifiedMemoryManager.getMemoryByComponentType("TestOffHeap") === 0L) + + // Set off-heap memory usage - this should NOT affect on-heap allocations + offHeapConsumer.memoryUsed = 200L + + // Wait for polling to pick up the change + Thread.sleep(200) + + // Test that off-heap unmanaged memory doesn't affect on-heap execution memory allocation + val acquiredMemory = mm.acquireExecutionMemory(1000L, taskAttemptId, memoryMode) + // Should get full 1000 bytes since off-heap unmanaged memory doesn't affect on-heap pool + assert(acquiredMemory == 1000L) + + // Release execution memory + mm.releaseExecutionMemory(acquiredMemory, taskAttemptId, memoryMode) + + // Now set on-heap memory usage - this SHOULD affect on-heap allocations + onHeapConsumer.memoryUsed = 200L + Thread.sleep(200) + + // Test that on-heap unmanaged memory affects on-heap execution memory allocation + val acquiredMemory2 = mm.acquireExecutionMemory(900L, taskAttemptId, memoryMode) + // Should only get 800 bytes due to 200 bytes of on-heap unmanaged memory usage + assert(acquiredMemory2 == 800L) + + // Release execution memory to test storage allocation + mm.releaseExecutionMemory(acquiredMemory2, taskAttemptId, memoryMode) + + // Test storage memory with on-heap unmanaged memory consideration + onHeapConsumer.memoryUsed = 300L + Thread.sleep(200) + + // Storage should fail when block size + unmanaged memory > max memory + assert(!mm.acquireStorageMemory(dummyBlock, 800L, memoryMode)) + + // But smaller storage requests should succeed with unmanaged memory factored in + // With 300L on-heap unmanaged memory, effective max is 700L + 
assert(mm.acquireStorageMemory(dummyBlock, 600L, memoryMode)) + + } finally { + UnifiedMemoryManager.shutdownUnmanagedMemoryPoller() + UnifiedMemoryManager.clearUnmanagedMemoryUsers() + } + } + + test("unmanaged memory consumer registration and unregistration") { + val conf = new SparkConf() + .set(MEMORY_FRACTION, 1.0) + .set(TEST_MEMORY, 1000L) + .set(MEMORY_OFFHEAP_ENABLED, false) + .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 100L) + + val mm = UnifiedMemoryManager(conf, numCores = 1) + + class MockMemoryConsumer( + var memoryUsed: Long, + instanceId: String, + mode: MemoryMode = MemoryMode.ON_HEAP) extends UnmanagedMemoryConsumer { + override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("Test", instanceId) + override def memoryMode: MemoryMode = mode + override def getMemBytesUsed: Long = memoryUsed + } + + val consumer1 = new MockMemoryConsumer(100L, "test-instance-1") + val consumer2 = new MockMemoryConsumer(200L, "test-instance-2") + + try { + // Register consumers + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer1) + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer2) + + Thread.sleep(200) + assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 300L) + + // Unregister one consumer + UnifiedMemoryManager.unregisterUnmanagedMemoryConsumer(consumer1) + + Thread.sleep(200) + assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 200L) + + // Unregister second consumer + UnifiedMemoryManager.unregisterUnmanagedMemoryConsumer(consumer2) + + Thread.sleep(200) + assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 0L) + + } finally { + UnifiedMemoryManager.shutdownUnmanagedMemoryPoller() + UnifiedMemoryManager.clearUnmanagedMemoryUsers() + } + } + + test("unmanaged memory consumer auto-removal when returning -1") { + val conf = new SparkConf() + .set(MEMORY_FRACTION, 1.0) + .set(TEST_MEMORY, 1000L) + .set(MEMORY_OFFHEAP_ENABLED, false) + 
.set(UNMANAGED_MEMORY_POLLING_INTERVAL, 100L) + + val mm = UnifiedMemoryManager(conf, numCores = 1) + + class MockMemoryConsumer(var memoryUsed: Long) extends UnmanagedMemoryConsumer { + override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("Test", s"test-instance-${this.hashCode()}") + override def memoryMode: MemoryMode = MemoryMode.ON_HEAP + override def getMemBytesUsed: Long = memoryUsed + } + + val consumer1 = new MockMemoryConsumer(100L) + val consumer2 = new MockMemoryConsumer(200L) + + try { + // Register consumers + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer1) + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer2) + + Thread.sleep(200) + assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 300L) + + // Mark consumer1 as inactive + consumer1.memoryUsed = -1L + + // Wait for polling to detect and remove the inactive consumer + Thread.sleep(200) + assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 200L) + + // Mark consumer2 as inactive as well + consumer2.memoryUsed = -1L + + Thread.sleep(200) + assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 0L) + + } finally { + UnifiedMemoryManager.shutdownUnmanagedMemoryPoller() + UnifiedMemoryManager.clearUnmanagedMemoryUsers() + } + } + + test("unmanaged memory polling disabled when interval is zero") { + val conf = new SparkConf() + .set(MEMORY_FRACTION, 1.0) + .set(TEST_MEMORY, 1000L) + .set(MEMORY_OFFHEAP_ENABLED, false) + .set(MEMORY_STORAGE_FRACTION, storageFraction) + .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 0L) // Disabled + + val mm = UnifiedMemoryManager(conf, numCores = 1) + + // When polling is disabled, unmanaged memory should not affect allocations + class MockUnmanagedMemoryConsumer(var memoryUsed: Long) extends UnmanagedMemoryConsumer { + override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("Test", "test-instance") + override def memoryMode: 
MemoryMode = MemoryMode.ON_HEAP + override def getMemBytesUsed: Long = memoryUsed + } + + val consumer = new MockUnmanagedMemoryConsumer(500L) + + try { + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer) + + // Since polling is disabled, should be able to allocate full memory + val acquiredMemory = mm.acquireExecutionMemory(1000L, 0L, MemoryMode.ON_HEAP) + assert(acquiredMemory === 1000L) + + } finally { + UnifiedMemoryManager.shutdownUnmanagedMemoryPoller() + UnifiedMemoryManager.clearUnmanagedMemoryUsers() + } + } + + test("unmanaged memory tracking with off-heap memory enabled") { + val maxOnHeapMemory = 1000L + val maxOffHeapMemory = 1500L + val taskAttemptId = 0L + val conf = new SparkConf() + .set(MEMORY_FRACTION, 1.0) + .set(TEST_MEMORY, maxOnHeapMemory) + .set(MEMORY_OFFHEAP_ENABLED, true) + .set(MEMORY_OFFHEAP_SIZE, maxOffHeapMemory) + .set(MEMORY_STORAGE_FRACTION, storageFraction) + .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 100L) + val mm = UnifiedMemoryManager(conf, numCores = 1) + + // Mock unmanaged memory consumer + class MockUnmanagedMemoryConsumer(var memoryUsed: Long) extends UnmanagedMemoryConsumer { + override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("ExternalLib", "test-instance") + + override def memoryMode: MemoryMode = MemoryMode.OFF_HEAP + + override def getMemBytesUsed: Long = memoryUsed + } + + val unmanagedConsumer = new MockUnmanagedMemoryConsumer(0L) + + try { + // Register the unmanaged memory consumer + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(unmanagedConsumer) + + // Test off-heap memory allocation with unmanaged memory + unmanagedConsumer.memoryUsed = 300L + Thread.sleep(200) + + // Test off-heap execution memory + // With 300 bytes of unmanaged memory, effective off-heap memory should be reduced + val offHeapAcquired = mm.acquireExecutionMemory(1400L, taskAttemptId, MemoryMode.OFF_HEAP) + assert(offHeapAcquired <= 1200L, "Off-heap memory should be reduced by 
unmanaged usage") + mm.releaseExecutionMemory(offHeapAcquired, taskAttemptId, MemoryMode.OFF_HEAP) + + // Test off-heap storage memory + unmanagedConsumer.memoryUsed = 500L + Thread.sleep(200) + + // Storage should fail when block size + unmanaged memory > max off-heap memory + assert(!mm.acquireStorageMemory(dummyBlock, 1100L, MemoryMode.OFF_HEAP)) + + // But smaller off-heap storage requests should succeed + assert(mm.acquireStorageMemory(dummyBlock, 900L, MemoryMode.OFF_HEAP)) + mm.releaseStorageMemory(900L, MemoryMode.OFF_HEAP) + + // Test that on-heap is NOT affected by off-heap unmanaged memory + val onHeapAcquired = mm.acquireExecutionMemory(600L, taskAttemptId, MemoryMode.ON_HEAP) + assert(onHeapAcquired == 600L, + "On-heap memory should not be reduced by off-heap unmanaged usage") + mm.releaseExecutionMemory(onHeapAcquired, taskAttemptId, MemoryMode.ON_HEAP) + + // Test with mixed memory modes + unmanagedConsumer.memoryUsed = 200L + Thread.sleep(200) + + // Allocate some on-heap and off-heap memory + val onHeap = mm.acquireExecutionMemory(400L, taskAttemptId, MemoryMode.ON_HEAP) + val offHeap = mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.OFF_HEAP) + + assert(onHeap == 400L && offHeap <= 1300L, + "Off-heap memory pool should respect unmanaged memory usage, on-heap should not") + + } finally { + UnifiedMemoryManager.shutdownUnmanagedMemoryPoller() + UnifiedMemoryManager.clearUnmanagedMemoryUsers() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 90662cbb6ca7..dd8a99500a1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -27,6 +27,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong} import scala.collection.{mutable, Map} import 
scala.jdk.CollectionConverters.ConcurrentMapHasAsScala import scala.util.Try +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.json4s.{Formats, NoTypeHints} @@ -73,7 +74,8 @@ class RocksDB( useColumnFamilies: Boolean = false, enableStateStoreCheckpointIds: Boolean = false, partitionId: Int = 0, - eventForwarder: Option[RocksDBEventForwarder] = None) extends Logging { + eventForwarder: Option[RocksDBEventForwarder] = None, + uniqueId: String = "") extends Logging { import RocksDB._ @@ -181,6 +183,24 @@ class RocksDB( protected var sessionStateStoreCkptId: Option[String] = None protected[sql] val lineageManager: RocksDBLineageManager = new RocksDBLineageManager + // Memory tracking fields for unmanaged memory monitoring + // This allows the UnifiedMemoryManager to track RocksDB memory usage without + // directly accessing RocksDB from the polling thread, avoiding segmentation faults + + // Timestamp of the last memory usage update in milliseconds. + // Used to enforce the update interval and prevent excessive memory queries. + private val lastMemoryUpdateTime = new AtomicLong(0L) + + // Minimum interval between memory usage updates in milliseconds. + // This prevents performance impact from querying RocksDB memory too frequently. 
+ private val memoryUpdateIntervalMs = conf.memoryUpdateIntervalMs + + // Register with RocksDBMemoryManager if we have a unique ID + if (uniqueId.nonEmpty) { + // Initial registration with zero memory usage + RocksDBMemoryManager.updateMemoryUsage(uniqueId, 0L, conf.boundedMemoryUsage) + } + @volatile private var numKeysOnLoadedVersion = 0L @volatile private var numKeysOnWritingVersion = 0L @@ -573,6 +593,10 @@ class RocksDB( } else { loadWithoutCheckpointId(version, readOnly) } + + // Register with memory manager after successful load + updateMemoryUsageIfNeeded() + this } @@ -754,6 +778,7 @@ class RocksDB( def get( key: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Array[Byte] = { + updateMemoryUsageIfNeeded() val keyWithPrefix = if (useColumnFamilies) { encodeStateRowWithPrefix(key, cfName) } else { @@ -821,6 +846,7 @@ class RocksDB( value: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME, includesPrefix: Boolean = false): Unit = { + updateMemoryUsageIfNeeded() val keyWithPrefix = if (useColumnFamilies && !includesPrefix) { encodeStateRowWithPrefix(key, cfName) } else { @@ -848,6 +874,7 @@ class RocksDB( value: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME, includesPrefix: Boolean = false): Unit = { + updateMemoryUsageIfNeeded() val keyWithPrefix = if (useColumnFamilies && !includesPrefix) { encodeStateRowWithPrefix(key, cfName) } else { @@ -867,6 +894,7 @@ class RocksDB( key: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME, includesPrefix: Boolean = false): Unit = { + updateMemoryUsageIfNeeded() val keyWithPrefix = if (useColumnFamilies && !includesPrefix) { encodeStateRowWithPrefix(key, cfName) } else { @@ -882,6 +910,7 @@ class RocksDB( * Get an iterator of all committed and uncommitted key-value pairs. 
*/ def iterator(): Iterator[ByteArrayPair] = { + updateMemoryUsageIfNeeded() val iter = db.newIterator() logInfo(log"Getting iterator from version ${MDC(LogKeys.LOADED_VERSION, loadedVersion)}") iter.seekToFirst() @@ -918,6 +947,7 @@ class RocksDB( * Get an iterator of all committed and uncommitted key-value pairs for the given column family. */ def iterator(cfName: String): Iterator[ByteArrayPair] = { + updateMemoryUsageIfNeeded() if (!useColumnFamilies) { iterator() } else { @@ -967,6 +997,7 @@ class RocksDB( def prefixScan( prefix: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[ByteArrayPair] = { + updateMemoryUsageIfNeeded() val iter = db.newIterator() val updatedPrefix = if (useColumnFamilies) { encodeStateRowWithPrefix(prefix, cfName) @@ -1013,6 +1044,7 @@ class RocksDB( * - Sync the checkpoint dir files to DFS */ def commit(): (Long, StateStoreCheckpointInfo) = { + updateMemoryUsageIfNeeded() val newVersion = loadedVersion + 1 try { logInfo(log"Flushing updates for ${MDC(LogKeys.VERSION_NUM, newVersion)}") @@ -1227,6 +1259,17 @@ class RocksDB( snapshot = snapshotsToUploadQueue.poll() } + // Unregister from RocksDBMemoryManager + if (uniqueId.nonEmpty) { + try { + RocksDBMemoryManager.unregisterInstance(uniqueId) + } catch { + case NonFatal(e) => + logWarning(log"Failed to unregister from RocksDBMemoryManager " + + log"${MDC(LogKeys.EXCEPTION, e)}") + } + } + silentDeleteRecursively(localRootDir, "closing RocksDB") // Clear internal maps to reset the state clearColFamilyMaps() @@ -1339,6 +1382,53 @@ class RocksDB( private def getDBProperty(property: String): Long = db.getProperty(property).toLong + /** + * Returns the current memory usage of this RocksDB instance in bytes. + * WARNING: This method should only be called from the task thread when + * RocksDB is in a safe state. 
+ * + * This includes memory from all major RocksDB components: + * - Table readers (indexes and filters in memory) + * - Memtables (write buffers) + * - Block cache (cached data blocks) + * - Block cache pinned usage (blocks pinned in cache) + * + * @return Total memory usage in bytes across all tracked components + */ + def getMemoryUsage: Long = { + + require(db != null && !db.isClosed, "RocksDB must be open to get memory usage") + RocksDB.mainMemorySources.map { memorySource => + getDBProperty(memorySource) + }.sum + } + + /** + * Updates the cached memory usage if enough time has passed. + * This is called from task thread operations, so it's already thread-safe. + */ + def updateMemoryUsageIfNeeded(): Unit = { + if (uniqueId.isEmpty) return // No tracking without unique ID + + val currentTime = System.currentTimeMillis() + val timeSinceLastUpdate = currentTime - lastMemoryUpdateTime.get() + + if (timeSinceLastUpdate >= memoryUpdateIntervalMs && db != null && !db.isClosed) { + try { + val usage = getMemoryUsage + lastMemoryUpdateTime.set(currentTime) + // Report usage to RocksDBMemoryManager + RocksDBMemoryManager.updateMemoryUsage( + uniqueId, + usage, + conf.boundedMemoryUsage) + } catch { + case NonFatal(e) => + logDebug(s"Failed to update RocksDB memory usage: ${e.getMessage}") + } + } + } + private def openDB(): Unit = { assert(db == null) db = NativeRocksDB.open(rocksDbOptions, workingDir.toString) @@ -1458,6 +1548,13 @@ class RocksDB( } object RocksDB extends Logging { + + val mainMemorySources: Seq[String] = Seq( + "rocksdb.estimate-table-readers-mem", + "rocksdb.cur-size-all-mem-tables", + "rocksdb.block-cache-usage", + "rocksdb.block-cache-pinned-usage") + case class RocksDBSnapshot( checkpointDir: File, version: Long, @@ -1699,6 +1796,7 @@ case class RocksDBConf( totalMemoryUsageMB: Long, writeBufferCacheRatio: Double, highPriorityPoolRatio: Double, + memoryUpdateIntervalMs: Long, compressionCodec: String, allowFAllocate: Boolean, compression: 
String, @@ -1785,6 +1883,12 @@ object RocksDBConf { private val HIGH_PRIORITY_POOL_RATIO_CONF = SQLConfEntry(HIGH_PRIORITY_POOL_RATIO_CONF_KEY, "0.1") + // Minimum interval (ms) between memory usage reports for unmanaged memory tracking + val MEMORY_UPDATE_INTERVAL_MS_CONF_KEY = "memoryUpdateIntervalMs" + private val MEMORY_UPDATE_INTERVAL_MS_CONF = SQLConfEntry(MEMORY_UPDATE_INTERVAL_MS_CONF_KEY, + "1000") + + // Allow files to be pre-allocated on disk using fallocate // Disabling may slow writes, but can solve an issue where // significant quantities of disk are wasted if there are @@ -1883,6 +1987,7 @@ object RocksDBConf { getLongConf(MAX_MEMORY_USAGE_MB_CONF), getRatioConf(WRITE_BUFFER_CACHE_RATIO_CONF), getRatioConf(HIGH_PRIORITY_POOL_RATIO_CONF), + getPositiveLongConf(MEMORY_UPDATE_INTERVAL_MS_CONF), storeConf.compressionCodec, getBooleanConf(ALLOW_FALLOCATE_CONF), getStringConf(COMPRESSION_CONF), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala index 273cbbc5e87d..80ad864600b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala @@ -17,22 +17,93 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.concurrent.ConcurrentHashMap + +import scala.jdk.CollectionConverters._ + import org.rocksdb._ +import org.apache.spark.SparkEnv import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ +import org.apache.spark.memory.{MemoryMode, UnifiedMemoryManager, UnmanagedMemoryConsumer, UnmanagedMemoryConsumerId} /** * Singleton responsible for managing cache and write buffer manager associated with all RocksDB * state store instances running on a single executor if boundedMemoryUsage is enabled for RocksDB. 
* If boundedMemoryUsage is disabled, a new cache object is returned. + * This also implements UnmanagedMemoryConsumer to report RocksDB memory usage to Spark's + * UnifiedMemoryManager, allowing Spark to account for RocksDB memory when making + * memory allocation decisions. */ -object RocksDBMemoryManager extends Logging { +object RocksDBMemoryManager extends Logging with UnmanagedMemoryConsumer { private var writeBufferManager: WriteBufferManager = null private var cache: Cache = null + // Tracks memory usage and bounded memory mode per unique ID + private case class InstanceMemoryInfo(memoryUsage: Long, isBoundedMemory: Boolean) + private val instanceMemoryMap = new ConcurrentHashMap[String, InstanceMemoryInfo]() + + override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = { + UnmanagedMemoryConsumerId("RocksDB", "RocksDB-Memory-Manager") + } + + override def memoryMode: MemoryMode = { + // RocksDB uses native/off-heap memory for its data structures + MemoryMode.OFF_HEAP + } + + override def getMemBytesUsed: Long = { + val memoryInfos = instanceMemoryMap.values().asScala.toSeq + if (memoryInfos.isEmpty) { + return 0L + } + + // Separate instances by bounded vs unbounded memory mode + val (bounded, unbounded) = memoryInfos.partition(_.isBoundedMemory) + + // For bounded memory instances, they all share the same memory pool, + // so just take the max value (they should all be similar) + val boundedMemory = bounded.map(_.memoryUsage).maxOption.getOrElse(0L) + + // For unbounded memory instances, sum their individual usages + val unboundedMemory = unbounded.map(_.memoryUsage).sum + + // Total is bounded memory (shared) + sum of unbounded memory (individual) + boundedMemory + unboundedMemory + } + + /** + * Register/update a RocksDB instance with its memory usage. 
+ * @param uniqueId The instance's unique identifier + * @param memoryUsage The current memory usage in bytes + * @param isBoundedMemory Whether this instance uses bounded memory mode + */ + def updateMemoryUsage( + uniqueId: String, + memoryUsage: Long, + isBoundedMemory: Boolean): Unit = { + instanceMemoryMap.put(uniqueId, InstanceMemoryInfo(memoryUsage, isBoundedMemory)) + logDebug(s"Updated memory usage for $uniqueId: $memoryUsage bytes " + + s"(bounded=$isBoundedMemory)") + } + + /** + * Unregister a RocksDB instance by removing its entry from the tracked memory usage map. + * @param uniqueId The instance's unique identifier + */ + def unregisterInstance(uniqueId: String): Unit = { + instanceMemoryMap.remove(uniqueId) + logDebug(s"Unregistered instance $uniqueId") + } + def getOrCreateRocksDBMemoryManagerAndCache(conf: RocksDBConf): (WriteBufferManager, Cache) = synchronized { + // Register with UnifiedMemoryManager (idempotent operation) + if (SparkEnv.get != null) { + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(this) + } + if (conf.boundedMemoryUsage) { if (writeBufferManager == null) { assert(cache == null) @@ -72,5 +143,6 @@ object RocksDBMemoryManager extends Logging { def resetWriteBufferManagerAndCache: Unit = synchronized { writeBufferManager = null cache = null + instanceMemoryMap.clear() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index bfcb2cdda296..cc21556fdd14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -556,6 +556,7 @@ private[sql] class RocksDBStateStoreProvider this.rocksDBEventForwarder = Some(RocksDBEventForwarder(StateStoreProvider.getRunId(hadoopConf), stateStoreId)) + // Initialize StateStoreProviderId for memory tracking val 
queryRunId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) this.stateStoreProviderId = StateStoreProviderId(stateStoreId, queryRunId) @@ -769,7 +770,8 @@ private[sql] class RocksDBStateStoreProvider useColumnFamilies: Boolean, enableStateStoreCheckpointIds: Boolean, partitionId: Int = 0, - eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = { + eventForwarder: Option[RocksDBEventForwarder] = None, + uniqueId: String = ""): RocksDB = { new RocksDB( dfsRootDir, conf, @@ -779,7 +781,8 @@ private[sql] class RocksDBStateStoreProvider useColumnFamilies, enableStateStoreCheckpointIds, partitionId, - eventForwarder) + eventForwarder, + uniqueId) } private[sql] lazy val rocksDB = { @@ -791,7 +794,7 @@ private[sql] class RocksDBStateStoreProvider val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) createRocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, loggingId, useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId, - rocksDBEventForwarder) + rocksDBEventForwarder, stateStoreProviderId.toString) } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala index 9429cd5ef39e..fb698c89ff8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala @@ -258,7 +258,8 @@ class FailureInjectionRocksDBStateStoreProvider extends RocksDBStateStoreProvide useColumnFamilies: Boolean, enableStateStoreCheckpointIds: Boolean, partitionId: Int, - eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = { + eventForwarder: 
Option[RocksDBEventForwarder] = None, + uniqueId: String): RocksDB = { FailureInjectionRocksDBStateStoreProvider.createRocksDBWithFaultInjection( dfsRootDir, conf, @@ -268,7 +269,8 @@ class FailureInjectionRocksDBStateStoreProvider extends RocksDBStateStoreProvide useColumnFamilies, enableStateStoreCheckpointIds, partitionId, - eventForwarder) + eventForwarder, + uniqueId) } } @@ -286,7 +288,8 @@ object FailureInjectionRocksDBStateStoreProvider { useColumnFamilies: Boolean, enableStateStoreCheckpointIds: Boolean, partitionId: Int, - eventForwarder: Option[RocksDBEventForwarder]): RocksDB = { + eventForwarder: Option[RocksDBEventForwarder], + uniqueId: String): RocksDB = { new RocksDB( dfsRootDir, conf = conf, @@ -296,7 +299,8 @@ object FailureInjectionRocksDBStateStoreProvider { useColumnFamilies = useColumnFamilies, enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, partitionId = partitionId, - eventForwarder = eventForwarder + eventForwarder = eventForwarder, + uniqueId = uniqueId ) { override def createFileManager( dfsRootDir: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala index 0c3e457c8df1..4018971d20f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala @@ -602,7 +602,8 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest useColumnFamilies = true, enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, partitionId = 0, - eventForwarder = None) + eventForwarder = None, + uniqueId = "") db.load(version, checkpointId) func(db) } finally { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala index e0af281fecb9..d2c95dfe5016 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala @@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters.SetHasAsScala import org.scalatest.time.{Minute, Span} import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} -import org.apache.spark.sql.functions.count +import org.apache.spark.sql.functions.{count, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.OutputMode.Update @@ -314,4 +314,80 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest assert(changelogVersionsPresent(dirForPartition0) == List(3L, 4L)) assert(snapshotVersionsPresent(dirForPartition0).contains(5L)) } + + // Test with both bounded memory enabled and disabled + Seq(true, false).foreach { boundedMemoryEnabled => + test(s"RocksDB memory tracking integration with UnifiedMemoryManager" + + s" with boundedMemory=$boundedMemoryEnabled") { + withTempDir { dir => + withSQLConf( + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName), + (SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath), + (SQLConf.SHUFFLE_PARTITIONS.key -> "5"), + (SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> (5 * 60 * 1000).toString), + ("spark.memory.unmanagedMemoryPollingInterval" -> "100ms"), + ("spark.sql.streaming.stateStore.rocksdb.boundedMemoryUsage" -> + boundedMemoryEnabled.toString)) { + + import org.apache.spark.memory.UnifiedMemoryManager + import org.apache.spark.sql.streaming.Trigger + + // Use rate stream to ensure 
continuous state operations that trigger memory updates + val query = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") // Continuous but not overwhelming + .load() + .selectExpr("value % 100 as key", "value") + .groupBy("key") + .agg(count("*").as("count"), max("value").as("max_value")) + .writeStream + .format("console") + .outputMode("update") + .trigger(Trigger.ProcessingTime(200)) // Regular triggers to ensure state operations + .start() + + try { + // Let the stream run to establish RocksDB instances and generate state operations + Thread.sleep(2000) // 2 seconds should be enough for several processing cycles + + // Poll until RocksDB memory is reported (only asserted by the final check below) + var rocksDBMemory = 0L + var attempts = 0 + val maxAttempts = 15 // 15 attempts with 1-second intervals = 15 seconds max + + while (rocksDBMemory <= 0L && attempts < maxAttempts) { + Thread.sleep(1000) // Wait between checks to allow memory updates + rocksDBMemory = UnifiedMemoryManager.getMemoryByComponentType("RocksDB") + attempts += 1 + + if (rocksDBMemory > 0L) { + logInfo(s"RocksDB memory detected: $rocksDBMemory bytes " + + s"after $attempts attempts with boundedMemory=$boundedMemoryEnabled") + } + } + + // Verify memory tracking is still reported after continued operation + Thread.sleep(2000) // Let stream continue running + + val finalMemory = UnifiedMemoryManager.getMemoryByComponentType("RocksDB") + + // Memory should still be tracked: any non-zero value passes, fluctuation is allowed + assert(finalMemory > 0L, + s"RocksDB memory tracking should remain active during stream processing: " + + s"got $finalMemory bytes (initial: $rocksDBMemory) " + + s"with boundedMemory=$boundedMemoryEnabled") + + logInfo(s"RocksDB memory tracking test completed successfully: " + + s"initial=$rocksDBMemory bytes, final=$finalMemory bytes " + + s"with boundedMemory=$boundedMemoryEnabled") + + } finally { + query.stop() + // Clean up unmanaged memory users + 
UnifiedMemoryManager.clearUnmanagedMemoryUsers() + } + } + } + } + } }