Skip to content

Commit 9320aed

Browse files
authored
feat: Add a spark.comet.exec.memoryPool configuration for experimenting with various datafusion memory pool setups. (#1021)
1 parent 4f8ce75 commit 9320aed

File tree

5 files changed

+231
-22
lines changed

5 files changed

+231
-22
lines changed

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,15 @@ object CometConf extends ShimCometConf {
467467
.booleanConf
468468
.createWithDefault(false)
469469

470+
val COMET_EXEC_MEMORY_POOL_TYPE: ConfigEntry[String] = conf("spark.comet.exec.memoryPool")
471+
.doc(
472+
"The type of memory pool to be used for Comet native execution. " +
473+
"Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', " +
474+
"'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, " +
475+
"this config is 'greedy_task_shared'.")
476+
.stringConf
477+
.createWithDefault("greedy_task_shared")
478+
470479
val COMET_SCAN_PREFETCH_ENABLED: ConfigEntry[Boolean] =
471480
conf("spark.comet.scan.preFetch.enabled")
472481
.doc("Whether to enable pre-fetching feature of CometScan.")

docs/source/user-guide/configs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Comet provides the following configuration settings.
4848
| spark.comet.exec.hashJoin.enabled | Whether to enable hashJoin by default. | true |
4949
| spark.comet.exec.localLimit.enabled | Whether to enable localLimit by default. | true |
5050
| spark.comet.exec.memoryFraction | The fraction of memory from Comet memory overhead that the native memory manager can use for execution. The purpose of this config is to set aside memory for untracked data structures, as well as imprecise size estimation during memory acquisition. | 0.7 |
51+
| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, this config is 'greedy_task_shared'. | greedy_task_shared |
5152
| spark.comet.exec.project.enabled | Whether to enable project by default. | true |
5253
| spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
5354
| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. Compression can be disabled by setting spark.shuffle.compress=false. | zstd |

native/core/src/execution/jni_api.rs

Lines changed: 190 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ use datafusion::{
2424
physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream},
2525
prelude::{SessionConfig, SessionContext},
2626
};
27+
use datafusion_execution::memory_pool::{
28+
FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool,
29+
};
2730
use futures::poll;
2831
use jni::{
2932
errors::Result as JNIResult,
@@ -51,20 +54,26 @@ use datafusion_comet_proto::spark_operator::Operator;
5154
use datafusion_common::ScalarValue;
5255
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
5356
use futures::stream::StreamExt;
57+
use jni::sys::JNI_FALSE;
5458
use jni::{
5559
objects::GlobalRef,
5660
sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
5761
};
62+
use std::num::NonZeroUsize;
63+
use std::sync::Mutex;
5864
use tokio::runtime::Runtime;
5965

6066
use crate::execution::operators::ScanExec;
6167
use crate::execution::spark_plan::SparkPlan;
6268
use log::info;
69+
use once_cell::sync::{Lazy, OnceCell};
6370

6471
/// Comet native execution context. Kept alive across JNI calls.
6572
struct ExecutionContext {
6673
/// The id of the execution context.
6774
pub id: i64,
75+
/// Task attempt id
76+
pub task_attempt_id: i64,
6877
/// The deserialized Spark plan
6978
pub spark_plan: Operator,
7079
/// The DataFusion root operator converted from the `spark_plan`
@@ -89,6 +98,51 @@ struct ExecutionContext {
8998
pub explain_native: bool,
9099
/// Map of metrics name -> jstring object to cache jni_NewStringUTF calls.
91100
pub metrics_jstrings: HashMap<String, Arc<GlobalRef>>,
101+
/// Memory pool config
102+
pub memory_pool_config: MemoryPoolConfig,
103+
}
104+
105+
#[derive(PartialEq, Eq)]
106+
enum MemoryPoolType {
107+
Unified,
108+
Greedy,
109+
FairSpill,
110+
GreedyTaskShared,
111+
FairSpillTaskShared,
112+
GreedyGlobal,
113+
FairSpillGlobal,
114+
}
115+
116+
struct MemoryPoolConfig {
117+
pool_type: MemoryPoolType,
118+
pool_size: usize,
119+
}
120+
121+
impl MemoryPoolConfig {
122+
fn new(pool_type: MemoryPoolType, pool_size: usize) -> Self {
123+
Self {
124+
pool_type,
125+
pool_size,
126+
}
127+
}
128+
}
129+
130+
/// The per-task memory pools keyed by task attempt id.
131+
static TASK_SHARED_MEMORY_POOLS: Lazy<Mutex<HashMap<i64, PerTaskMemoryPool>>> =
132+
Lazy::new(|| Mutex::new(HashMap::new()));
133+
134+
struct PerTaskMemoryPool {
135+
memory_pool: Arc<dyn MemoryPool>,
136+
num_plans: usize,
137+
}
138+
139+
impl PerTaskMemoryPool {
140+
fn new(memory_pool: Arc<dyn MemoryPool>) -> Self {
141+
Self {
142+
memory_pool,
143+
num_plans: 0,
144+
}
145+
}
92146
}
93147

94148
/// Accept serialized query plan and return the address of the native query plan.
@@ -105,8 +159,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
105159
comet_task_memory_manager_obj: JObject,
106160
batch_size: jint,
107161
use_unified_memory_manager: jboolean,
162+
memory_pool_type: jstring,
108163
memory_limit: jlong,
164+
memory_limit_per_task: jlong,
109165
memory_fraction: jdouble,
166+
task_attempt_id: jlong,
110167
debug_native: jboolean,
111168
explain_native: jboolean,
112169
worker_threads: jint,
@@ -145,21 +202,27 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
145202
let task_memory_manager =
146203
Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);
147204

205+
let memory_pool_type = env.get_string(&JString::from_raw(memory_pool_type))?.into();
206+
let memory_pool_config = parse_memory_pool_config(
207+
use_unified_memory_manager != JNI_FALSE,
208+
memory_pool_type,
209+
memory_limit,
210+
memory_limit_per_task,
211+
memory_fraction,
212+
)?;
213+
let memory_pool =
214+
create_memory_pool(&memory_pool_config, task_memory_manager, task_attempt_id);
215+
148216
// We need to keep the session context alive. Some session state like temporary
149217
// dictionaries are stored in session context. If it is dropped, the temporary
150218
// dictionaries will be dropped as well.
151-
let session = prepare_datafusion_session_context(
152-
batch_size as usize,
153-
use_unified_memory_manager == 1,
154-
memory_limit as usize,
155-
memory_fraction,
156-
task_memory_manager,
157-
)?;
219+
let session = prepare_datafusion_session_context(batch_size as usize, memory_pool)?;
158220

159221
let plan_creation_time = start.elapsed();
160222

161223
let exec_context = Box::new(ExecutionContext {
162224
id,
225+
task_attempt_id,
163226
spark_plan,
164227
root_op: None,
165228
scans: vec![],
@@ -172,6 +235,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
172235
debug_native: debug_native == 1,
173236
explain_native: explain_native == 1,
174237
metrics_jstrings: HashMap::new(),
238+
memory_pool_config,
175239
});
176240

177241
Ok(Box::into_raw(exec_context) as i64)
@@ -181,22 +245,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
181245
/// Configure DataFusion session context.
182246
fn prepare_datafusion_session_context(
183247
batch_size: usize,
184-
use_unified_memory_manager: bool,
185-
memory_limit: usize,
186-
memory_fraction: f64,
187-
comet_task_memory_manager: Arc<GlobalRef>,
248+
memory_pool: Arc<dyn MemoryPool>,
188249
) -> CometResult<SessionContext> {
189250
let mut rt_config = RuntimeEnvBuilder::new().with_disk_manager(DiskManagerConfig::NewOs);
190-
191-
// Check if we are using unified memory manager integrated with Spark.
192-
if use_unified_memory_manager {
193-
// Set Comet memory pool for native
194-
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
195-
rt_config = rt_config.with_memory_pool(Arc::new(memory_pool));
196-
} else {
197-
// Use the memory pool from DF
198-
rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction)
199-
}
251+
rt_config = rt_config.with_memory_pool(memory_pool);
200252

201253
// Get Datafusion configuration from Spark Execution context
202254
// can be configured in Comet Spark JVM using Spark --conf parameters
@@ -224,6 +276,107 @@ fn prepare_datafusion_session_context(
224276
Ok(session_ctx)
225277
}
226278

279+
fn parse_memory_pool_config(
280+
use_unified_memory_manager: bool,
281+
memory_pool_type: String,
282+
memory_limit: i64,
283+
memory_limit_per_task: i64,
284+
memory_fraction: f64,
285+
) -> CometResult<MemoryPoolConfig> {
286+
let memory_pool_config = if use_unified_memory_manager {
287+
MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
288+
} else {
289+
// Use the memory pool from DF
290+
let pool_size = (memory_limit as f64 * memory_fraction) as usize;
291+
let pool_size_per_task = (memory_limit_per_task as f64 * memory_fraction) as usize;
292+
match memory_pool_type.as_str() {
293+
"fair_spill_task_shared" => {
294+
MemoryPoolConfig::new(MemoryPoolType::FairSpillTaskShared, pool_size_per_task)
295+
}
296+
"greedy_task_shared" => {
297+
MemoryPoolConfig::new(MemoryPoolType::GreedyTaskShared, pool_size_per_task)
298+
}
299+
"fair_spill_global" => {
300+
MemoryPoolConfig::new(MemoryPoolType::FairSpillGlobal, pool_size)
301+
}
302+
"greedy_global" => MemoryPoolConfig::new(MemoryPoolType::GreedyGlobal, pool_size),
303+
"fair_spill" => MemoryPoolConfig::new(MemoryPoolType::FairSpill, pool_size_per_task),
304+
"greedy" => MemoryPoolConfig::new(MemoryPoolType::Greedy, pool_size_per_task),
305+
_ => {
306+
return Err(CometError::Config(format!(
307+
"Unsupported memory pool type: {}",
308+
memory_pool_type
309+
)))
310+
}
311+
}
312+
};
313+
Ok(memory_pool_config)
314+
}
315+
316+
fn create_memory_pool(
317+
memory_pool_config: &MemoryPoolConfig,
318+
comet_task_memory_manager: Arc<GlobalRef>,
319+
task_attempt_id: i64,
320+
) -> Arc<dyn MemoryPool> {
321+
const NUM_TRACKED_CONSUMERS: usize = 10;
322+
match memory_pool_config.pool_type {
323+
MemoryPoolType::Unified => {
324+
// Set Comet memory pool for native
325+
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
326+
Arc::new(memory_pool)
327+
}
328+
MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new(
329+
GreedyMemoryPool::new(memory_pool_config.pool_size),
330+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
331+
)),
332+
MemoryPoolType::FairSpill => Arc::new(TrackConsumersPool::new(
333+
FairSpillPool::new(memory_pool_config.pool_size),
334+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
335+
)),
336+
MemoryPoolType::GreedyGlobal => {
337+
static GLOBAL_MEMORY_POOL_GREEDY: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
338+
let memory_pool = GLOBAL_MEMORY_POOL_GREEDY.get_or_init(|| {
339+
Arc::new(TrackConsumersPool::new(
340+
GreedyMemoryPool::new(memory_pool_config.pool_size),
341+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
342+
))
343+
});
344+
Arc::clone(memory_pool)
345+
}
346+
MemoryPoolType::FairSpillGlobal => {
347+
static GLOBAL_MEMORY_POOL_FAIR: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
348+
let memory_pool = GLOBAL_MEMORY_POOL_FAIR.get_or_init(|| {
349+
Arc::new(TrackConsumersPool::new(
350+
FairSpillPool::new(memory_pool_config.pool_size),
351+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
352+
))
353+
});
354+
Arc::clone(memory_pool)
355+
}
356+
MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared => {
357+
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
358+
let per_task_memory_pool =
359+
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
360+
let pool: Arc<dyn MemoryPool> =
361+
if memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared {
362+
Arc::new(TrackConsumersPool::new(
363+
GreedyMemoryPool::new(memory_pool_config.pool_size),
364+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
365+
))
366+
} else {
367+
Arc::new(TrackConsumersPool::new(
368+
FairSpillPool::new(memory_pool_config.pool_size),
369+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
370+
))
371+
};
372+
PerTaskMemoryPool::new(pool)
373+
});
374+
per_task_memory_pool.num_plans += 1;
375+
Arc::clone(&per_task_memory_pool.memory_pool)
376+
}
377+
}
378+
}
379+
227380
/// Prepares arrow arrays for output.
228381
fn prepare_output(
229382
env: &mut JNIEnv,
@@ -407,6 +560,22 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
407560
) {
408561
try_unwrap_or_throw(&e, |_| unsafe {
409562
let execution_context = get_execution_context(exec_context);
563+
if execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared
564+
|| execution_context.memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared
565+
{
566+
// Decrement the number of native plans using the per-task shared memory pool, and
567+
// remove the memory pool if the released native plan is the last native plan using it.
568+
let task_attempt_id = execution_context.task_attempt_id;
569+
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
570+
if let Some(per_task_memory_pool) = memory_pool_map.get_mut(&task_attempt_id) {
571+
per_task_memory_pool.num_plans -= 1;
572+
if per_task_memory_pool.num_plans == 0 {
573+
// Drop the memory pool from the per-task memory pool map if there are no
574+
// more native plans using it.
575+
memory_pool_map.remove(&task_attempt_id);
576+
}
577+
}
578+
}
410579
let _: Box<ExecutionContext> = Box::from_raw(execution_context);
411580
Ok(())
412581
})

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark._
2323
import org.apache.spark.sql.comet.CometMetricNode
2424
import org.apache.spark.sql.vectorized._
2525

26-
import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS}
26+
import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS}
2727
import org.apache.comet.vector.NativeUtil
2828

2929
/**
@@ -72,8 +72,11 @@ class CometExecIterator(
7272
new CometTaskMemoryManager(id),
7373
batchSize = COMET_BATCH_SIZE.get(),
7474
use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
75+
memory_pool_type = COMET_EXEC_MEMORY_POOL_TYPE.get(),
7576
memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
77+
memory_limit_per_task = getMemoryLimitPerTask(conf),
7678
memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
79+
task_attempt_id = TaskContext.get().taskAttemptId,
7780
debug = COMET_DEBUG_ENABLED.get(),
7881
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
7982
workerThreads = COMET_WORKER_THREADS.get(),
@@ -84,6 +87,30 @@ class CometExecIterator(
8487
private var currentBatch: ColumnarBatch = null
8588
private var closed: Boolean = false
8689

90+
private def getMemoryLimitPerTask(conf: SparkConf): Long = {
91+
val numCores = numDriverOrExecutorCores(conf).toFloat
92+
val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
93+
val coresPerTask = conf.get("spark.task.cpus", "1").toFloat
94+
// example 16GB maxMemory * 16 cores with 4 cores per task results
95+
// in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB
96+
(maxMemory.toFloat * coresPerTask / numCores).toLong
97+
}
98+
99+
private def numDriverOrExecutorCores(conf: SparkConf): Int = {
100+
def convertToInt(threads: String): Int = {
101+
if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt
102+
}
103+
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
104+
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
105+
val master = conf.get("spark.master")
106+
master match {
107+
case "local" => 1
108+
case LOCAL_N_REGEX(threads) => convertToInt(threads)
109+
case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads)
110+
case _ => conf.get("spark.executor.cores", "1").toInt
111+
}
112+
}
113+
87114
def getNextBatch(): Option[ColumnarBatch] = {
88115
assert(partitionIndex >= 0 && partitionIndex < numParts)
89116

spark/src/main/scala/org/apache/comet/Native.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,11 @@ class Native extends NativeBase {
5252
taskMemoryManager: CometTaskMemoryManager,
5353
batchSize: Int,
5454
use_unified_memory_manager: Boolean,
55+
memory_pool_type: String,
5556
memory_limit: Long,
57+
memory_limit_per_task: Long,
5658
memory_fraction: Double,
59+
task_attempt_id: Long,
5760
debug: Boolean,
5861
explain: Boolean,
5962
workerThreads: Int,

0 commit comments

Comments
 (0)