Skip to content

Commit 1fe9ab6

Browse files
committed
Make datafusion's native memory pool configurable
1 parent 58dee73 commit 1fe9ab6

File tree

5 files changed

+229
-22
lines changed

5 files changed

+229
-22
lines changed

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,15 @@ object CometConf extends ShimCometConf {
467467
.booleanConf
468468
.createWithDefault(false)
469469

470+
val COMET_EXEC_MEMORY_POOL_TYPE: ConfigEntry[String] = conf("spark.comet.exec.memoryPool")
471+
.doc(
472+
"The type of memory pool to be used for Comet native execution. " +
473+
"Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', " +
474+
"'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, " +
475+
"this config is 'greedy_task_shared'.")
476+
.stringConf
477+
.createWithDefault("greedy_task_shared")
478+
470479
val COMET_SCAN_PREFETCH_ENABLED: ConfigEntry[Boolean] =
471480
conf("spark.comet.scan.preFetch.enabled")
472481
.doc("Whether to enable pre-fetching feature of CometScan.")

docs/source/user-guide/configs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Comet provides the following configuration settings.
4848
| spark.comet.exec.hashJoin.enabled | Whether to enable hashJoin by default. | true |
4949
| spark.comet.exec.localLimit.enabled | Whether to enable localLimit by default. | true |
5050
| spark.comet.exec.memoryFraction | The fraction of memory from Comet memory overhead that the native memory manager can use for execution. The purpose of this config is to set aside memory for untracked data structures, as well as imprecise size estimation during memory acquisition. | 0.7 |
51+
| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, this config is 'greedy_task_shared'. | greedy_task_shared |
5152
| spark.comet.exec.project.enabled | Whether to enable project by default. | true |
5253
| spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
5354
| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. Compression can be disabled by setting spark.shuffle.compress=false. | zstd |

native/core/src/execution/jni_api.rs

Lines changed: 188 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ use datafusion::{
2727
physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream},
2828
prelude::{SessionConfig, SessionContext},
2929
};
30+
use datafusion_execution::memory_pool::{
31+
FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool,
32+
};
3033
use futures::poll;
3134
use jni::{
3235
errors::Result as JNIResult,
@@ -53,20 +56,26 @@ use crate::{
5356
use datafusion_comet_proto::spark_operator::Operator;
5457
use datafusion_common::ScalarValue;
5558
use futures::stream::StreamExt;
59+
use jni::sys::JNI_FALSE;
5660
use jni::{
5761
objects::GlobalRef,
5862
sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
5963
};
64+
use std::num::NonZeroUsize;
65+
use std::sync::Mutex;
6066
use tokio::runtime::Runtime;
6167

6268
use crate::execution::operators::ScanExec;
6369
use crate::execution::spark_plan::SparkPlan;
6470
use log::info;
71+
use once_cell::sync::{Lazy, OnceCell};
6572

6673
/// Comet native execution context. Kept alive across JNI calls.
6774
struct ExecutionContext {
6875
/// The id of the execution context.
6976
pub id: i64,
77+
/// Task attempt id
78+
pub task_attempt_id: i64,
7079
/// The deserialized Spark plan
7180
pub spark_plan: Operator,
7281
/// The DataFusion root operator converted from the `spark_plan`
@@ -91,6 +100,51 @@ struct ExecutionContext {
91100
pub explain_native: bool,
92101
/// Map of metrics name -> jstring object to cache jni_NewStringUTF calls.
93102
pub metrics_jstrings: HashMap<String, Arc<GlobalRef>>,
103+
/// Memory pool config
104+
pub memory_pool_config: MemoryPoolConfig,
105+
}
106+
107+
#[derive(PartialEq, Eq)]
108+
enum MemoryPoolType {
109+
Unified,
110+
Greedy,
111+
FairSpill,
112+
GreedyTaskShared,
113+
FairSpillTaskShared,
114+
GreedyGlobal,
115+
FairSpillGlobal,
116+
}
117+
118+
struct MemoryPoolConfig {
119+
pool_type: MemoryPoolType,
120+
pool_size: usize,
121+
}
122+
123+
impl MemoryPoolConfig {
124+
fn new(pool_type: MemoryPoolType, pool_size: usize) -> Self {
125+
Self {
126+
pool_type,
127+
pool_size,
128+
}
129+
}
130+
}
131+
132+
/// The per-task memory pools keyed by task attempt id.
133+
static TASK_SHARED_MEMORY_POOLS: Lazy<Mutex<HashMap<i64, PerTaskMemoryPool>>> =
134+
Lazy::new(|| Mutex::new(HashMap::new()));
135+
136+
struct PerTaskMemoryPool {
137+
memory_pool: Arc<dyn MemoryPool>,
138+
num_plans: usize,
139+
}
140+
141+
impl PerTaskMemoryPool {
142+
fn new(memory_pool: Arc<dyn MemoryPool>) -> Self {
143+
Self {
144+
memory_pool,
145+
num_plans: 0,
146+
}
147+
}
94148
}
95149

96150
/// Accept serialized query plan and return the address of the native query plan.
@@ -107,8 +161,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
107161
comet_task_memory_manager_obj: JObject,
108162
batch_size: jint,
109163
use_unified_memory_manager: jboolean,
164+
memory_pool_type: jstring,
110165
memory_limit: jlong,
166+
memory_limit_per_task: jlong,
111167
memory_fraction: jdouble,
168+
task_attempt_id: jlong,
112169
debug_native: jboolean,
113170
explain_native: jboolean,
114171
worker_threads: jint,
@@ -147,21 +204,27 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
147204
let task_memory_manager =
148205
Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);
149206

207+
let memory_pool_type = env.get_string(&JString::from_raw(memory_pool_type))?.into();
208+
let memory_pool_config = parse_memory_pool_config(
209+
use_unified_memory_manager != JNI_FALSE,
210+
memory_pool_type,
211+
memory_limit,
212+
memory_limit_per_task,
213+
memory_fraction,
214+
)?;
215+
let memory_pool =
216+
create_memory_pool(&memory_pool_config, task_memory_manager, task_attempt_id);
217+
150218
// We need to keep the session context alive. Some session state like temporary
151219
// dictionaries are stored in session context. If it is dropped, the temporary
152220
// dictionaries will be dropped as well.
153-
let session = prepare_datafusion_session_context(
154-
batch_size as usize,
155-
use_unified_memory_manager == 1,
156-
memory_limit as usize,
157-
memory_fraction,
158-
task_memory_manager,
159-
)?;
221+
let session = prepare_datafusion_session_context(batch_size as usize, memory_pool)?;
160222

161223
let plan_creation_time = start.elapsed();
162224

163225
let exec_context = Box::new(ExecutionContext {
164226
id,
227+
task_attempt_id,
165228
spark_plan,
166229
root_op: None,
167230
scans: vec![],
@@ -174,6 +237,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
174237
debug_native: debug_native == 1,
175238
explain_native: explain_native == 1,
176239
metrics_jstrings: HashMap::new(),
240+
memory_pool_config,
177241
});
178242

179243
Ok(Box::into_raw(exec_context) as i64)
@@ -183,22 +247,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
183247
/// Configure DataFusion session context.
184248
fn prepare_datafusion_session_context(
185249
batch_size: usize,
186-
use_unified_memory_manager: bool,
187-
memory_limit: usize,
188-
memory_fraction: f64,
189-
comet_task_memory_manager: Arc<GlobalRef>,
250+
memory_pool: Arc<dyn MemoryPool>,
190251
) -> CometResult<SessionContext> {
191252
let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);
192-
193-
// Check if we are using unified memory manager integrated with Spark.
194-
if use_unified_memory_manager {
195-
// Set Comet memory pool for native
196-
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
197-
rt_config = rt_config.with_memory_pool(Arc::new(memory_pool));
198-
} else {
199-
// Use the memory pool from DF
200-
rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction)
201-
}
253+
rt_config = rt_config.with_memory_pool(memory_pool);
202254

203255
// Get Datafusion configuration from Spark Execution context
204256
// can be configured in Comet Spark JVM using Spark --conf parameters
@@ -225,6 +277,107 @@ fn prepare_datafusion_session_context(
225277
Ok(session_ctx)
226278
}
227279

280+
fn parse_memory_pool_config(
281+
use_unified_memory_manager: bool,
282+
memory_pool_type: String,
283+
memory_limit: i64,
284+
memory_limit_per_task: i64,
285+
memory_fraction: f64,
286+
) -> CometResult<MemoryPoolConfig> {
287+
let memory_pool_config = if use_unified_memory_manager {
288+
MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
289+
} else {
290+
// Use the memory pool from DF
291+
let pool_size = (memory_limit as f64 * memory_fraction) as usize;
292+
let pool_size_per_task = (memory_limit_per_task as f64 * memory_fraction) as usize;
293+
match memory_pool_type.as_str() {
294+
"fair_spill_task_shared" => {
295+
MemoryPoolConfig::new(MemoryPoolType::FairSpillTaskShared, pool_size_per_task)
296+
}
297+
"greedy_task_shared" => {
298+
MemoryPoolConfig::new(MemoryPoolType::GreedyTaskShared, pool_size_per_task)
299+
}
300+
"fair_spill_global" => {
301+
MemoryPoolConfig::new(MemoryPoolType::FairSpillGlobal, pool_size)
302+
}
303+
"greedy_global" => MemoryPoolConfig::new(MemoryPoolType::GreedyGlobal, pool_size),
304+
"fair_spill" => MemoryPoolConfig::new(MemoryPoolType::FairSpill, pool_size_per_task),
305+
"greedy" => MemoryPoolConfig::new(MemoryPoolType::Greedy, pool_size_per_task),
306+
_ => {
307+
return Err(CometError::Config(format!(
308+
"Unsupported memory pool type: {}",
309+
memory_pool_type
310+
)))
311+
}
312+
}
313+
};
314+
Ok(memory_pool_config)
315+
}
316+
317+
fn create_memory_pool(
318+
memory_pool_config: &MemoryPoolConfig,
319+
comet_task_memory_manager: Arc<GlobalRef>,
320+
task_attempt_id: i64,
321+
) -> Arc<dyn MemoryPool> {
322+
const NUM_TRACKED_CONSUMERS: usize = 10;
323+
match memory_pool_config.pool_type {
324+
MemoryPoolType::Unified => {
325+
// Set Comet memory pool for native
326+
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
327+
Arc::new(memory_pool)
328+
}
329+
MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new(
330+
GreedyMemoryPool::new(memory_pool_config.pool_size),
331+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
332+
)),
333+
MemoryPoolType::FairSpill => Arc::new(TrackConsumersPool::new(
334+
FairSpillPool::new(memory_pool_config.pool_size),
335+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
336+
)),
337+
MemoryPoolType::GreedyGlobal => {
338+
static GLOBAL_MEMORY_POOL_GREEDY: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
339+
let memory_pool = GLOBAL_MEMORY_POOL_GREEDY.get_or_init(|| {
340+
Arc::new(TrackConsumersPool::new(
341+
GreedyMemoryPool::new(memory_pool_config.pool_size),
342+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
343+
))
344+
});
345+
Arc::clone(memory_pool)
346+
}
347+
MemoryPoolType::FairSpillGlobal => {
348+
static GLOBAL_MEMORY_POOL_FAIR: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
349+
let memory_pool = GLOBAL_MEMORY_POOL_FAIR.get_or_init(|| {
350+
Arc::new(TrackConsumersPool::new(
351+
FairSpillPool::new(memory_pool_config.pool_size),
352+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
353+
))
354+
});
355+
Arc::clone(memory_pool)
356+
}
357+
MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared => {
358+
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
359+
let per_task_memory_pool =
360+
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
361+
let pool: Arc<dyn MemoryPool> =
362+
if memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared {
363+
Arc::new(TrackConsumersPool::new(
364+
GreedyMemoryPool::new(memory_pool_config.pool_size),
365+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
366+
))
367+
} else {
368+
Arc::new(TrackConsumersPool::new(
369+
FairSpillPool::new(memory_pool_config.pool_size),
370+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
371+
))
372+
};
373+
PerTaskMemoryPool::new(pool)
374+
});
375+
per_task_memory_pool.num_plans += 1;
376+
Arc::clone(&per_task_memory_pool.memory_pool)
377+
}
378+
}
379+
}
380+
228381
/// Prepares arrow arrays for output.
229382
fn prepare_output(
230383
env: &mut JNIEnv,
@@ -408,6 +561,20 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
408561
) {
409562
try_unwrap_or_throw(&e, |_| unsafe {
410563
let execution_context = get_execution_context(exec_context);
564+
if execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared {
565+
// Decrement the number of native plans using the per-task shared memory pool, and
566+
// remove the memory pool if the released native plan is the last native plan using it.
567+
let task_attempt_id = execution_context.task_attempt_id;
568+
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
569+
if let Some(per_task_memory_pool) = memory_pool_map.get_mut(&task_attempt_id) {
570+
per_task_memory_pool.num_plans -= 1;
571+
if per_task_memory_pool.num_plans == 0 {
572+
// Drop the memory pool from the per-task memory pool map if there are no
573+
// more native plans using it.
574+
memory_pool_map.remove(&task_attempt_id);
575+
}
576+
}
577+
}
411578
let _: Box<ExecutionContext> = Box::from_raw(execution_context);
412579
Ok(())
413580
})

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark._
2323
import org.apache.spark.sql.comet.CometMetricNode
2424
import org.apache.spark.sql.vectorized._
2525

26-
import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS}
26+
import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS}
2727
import org.apache.comet.vector.NativeUtil
2828

2929
/**
@@ -72,8 +72,11 @@ class CometExecIterator(
7272
new CometTaskMemoryManager(id),
7373
batchSize = COMET_BATCH_SIZE.get(),
7474
use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
75+
memory_pool_type = COMET_EXEC_MEMORY_POOL_TYPE.get(),
7576
memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
77+
memory_limit_per_task = getMemoryLimitPerTask(conf),
7678
memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
79+
task_attempt_id = TaskContext.get().taskAttemptId,
7780
debug = COMET_DEBUG_ENABLED.get(),
7881
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
7982
workerThreads = COMET_WORKER_THREADS.get(),
@@ -84,6 +87,30 @@ class CometExecIterator(
8487
private var currentBatch: ColumnarBatch = null
8588
private var closed: Boolean = false
8689

90+
private def getMemoryLimitPerTask(conf: SparkConf): Long = {
91+
val numCores = numDriverOrExecutorCores(conf).toFloat
92+
val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
93+
val coresPerTask = conf.get("spark.task.cpus", "1").toFloat
94+
// example 16GB maxMemory * 16 cores with 4 cores per task results
95+
// in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB
96+
(maxMemory.toFloat * coresPerTask / numCores).toLong
97+
}
98+
99+
private def numDriverOrExecutorCores(conf: SparkConf): Int = {
100+
def convertToInt(threads: String): Int = {
101+
if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt
102+
}
103+
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
104+
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
105+
val master = conf.get("spark.master")
106+
master match {
107+
case "local" => 1
108+
case LOCAL_N_REGEX(threads) => convertToInt(threads)
109+
case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads)
110+
case _ => conf.get("spark.executor.cores", "1").toInt
111+
}
112+
}
113+
87114
def getNextBatch(): Option[ColumnarBatch] = {
88115
assert(partitionIndex >= 0 && partitionIndex < numParts)
89116

spark/src/main/scala/org/apache/comet/Native.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,11 @@ class Native extends NativeBase {
5252
taskMemoryManager: CometTaskMemoryManager,
5353
batchSize: Int,
5454
use_unified_memory_manager: Boolean,
55+
memory_pool_type: String,
5556
memory_limit: Long,
57+
memory_limit_per_task: Long,
5658
memory_fraction: Double,
59+
task_attempt_id: Long,
5760
debug: Boolean,
5861
explain: Boolean,
5962
workerThreads: Int,

0 commit comments

Comments
 (0)