@@ -24,6 +24,9 @@ use datafusion::{
2424 physical_plan:: { display:: DisplayableExecutionPlan , SendableRecordBatchStream } ,
2525 prelude:: { SessionConfig , SessionContext } ,
2626} ;
27+ use datafusion_execution:: memory_pool:: {
28+ FairSpillPool , GreedyMemoryPool , MemoryPool , TrackConsumersPool ,
29+ } ;
2730use futures:: poll;
2831use jni:: {
2932 errors:: Result as JNIResult ,
@@ -51,20 +54,26 @@ use datafusion_comet_proto::spark_operator::Operator;
5154use datafusion_common:: ScalarValue ;
5255use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
5356use futures:: stream:: StreamExt ;
57+ use jni:: sys:: JNI_FALSE ;
5458use jni:: {
5559 objects:: GlobalRef ,
5660 sys:: { jboolean, jdouble, jintArray, jobjectArray, jstring} ,
5761} ;
62+ use std:: num:: NonZeroUsize ;
63+ use std:: sync:: Mutex ;
5864use tokio:: runtime:: Runtime ;
5965
6066use crate :: execution:: operators:: ScanExec ;
6167use crate :: execution:: spark_plan:: SparkPlan ;
6268use log:: info;
69+ use once_cell:: sync:: { Lazy , OnceCell } ;
6370
6471/// Comet native execution context. Kept alive across JNI calls.
6572struct ExecutionContext {
6673 /// The id of the execution context.
6774 pub id : i64 ,
75+ /// Task attempt id
76+ pub task_attempt_id : i64 ,
6877 /// The deserialized Spark plan
6978 pub spark_plan : Operator ,
7079 /// The DataFusion root operator converted from the `spark_plan`
@@ -89,6 +98,51 @@ struct ExecutionContext {
8998 pub explain_native : bool ,
9099 /// Map of metrics name -> jstring object to cache jni_NewStringUTF calls.
91100 pub metrics_jstrings : HashMap < String , Arc < GlobalRef > > ,
101+ /// Memory pool config
102+ pub memory_pool_config : MemoryPoolConfig ,
103+ }
104+
105+ #[ derive( PartialEq , Eq ) ]
106+ enum MemoryPoolType {
107+ Unified ,
108+ Greedy ,
109+ FairSpill ,
110+ GreedyTaskShared ,
111+ FairSpillTaskShared ,
112+ GreedyGlobal ,
113+ FairSpillGlobal ,
114+ }
115+
116+ struct MemoryPoolConfig {
117+ pool_type : MemoryPoolType ,
118+ pool_size : usize ,
119+ }
120+
121+ impl MemoryPoolConfig {
122+ fn new ( pool_type : MemoryPoolType , pool_size : usize ) -> Self {
123+ Self {
124+ pool_type,
125+ pool_size,
126+ }
127+ }
128+ }
129+
130+ /// The per-task memory pools keyed by task attempt id.
131+ static TASK_SHARED_MEMORY_POOLS : Lazy < Mutex < HashMap < i64 , PerTaskMemoryPool > > > =
132+ Lazy :: new ( || Mutex :: new ( HashMap :: new ( ) ) ) ;
133+
134+ struct PerTaskMemoryPool {
135+ memory_pool : Arc < dyn MemoryPool > ,
136+ num_plans : usize ,
137+ }
138+
139+ impl PerTaskMemoryPool {
140+ fn new ( memory_pool : Arc < dyn MemoryPool > ) -> Self {
141+ Self {
142+ memory_pool,
143+ num_plans : 0 ,
144+ }
145+ }
92146}
93147
94148/// Accept serialized query plan and return the address of the native query plan.
@@ -105,8 +159,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
105159 comet_task_memory_manager_obj : JObject ,
106160 batch_size : jint ,
107161 use_unified_memory_manager : jboolean ,
162+ memory_pool_type : jstring ,
108163 memory_limit : jlong ,
164+ memory_limit_per_task : jlong ,
109165 memory_fraction : jdouble ,
166+ task_attempt_id : jlong ,
110167 debug_native : jboolean ,
111168 explain_native : jboolean ,
112169 worker_threads : jint ,
@@ -145,21 +202,27 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
145202 let task_memory_manager =
146203 Arc :: new ( jni_new_global_ref ! ( env, comet_task_memory_manager_obj) ?) ;
147204
205+ let memory_pool_type = env. get_string ( & JString :: from_raw ( memory_pool_type) ) ?. into ( ) ;
206+ let memory_pool_config = parse_memory_pool_config (
207+ use_unified_memory_manager != JNI_FALSE ,
208+ memory_pool_type,
209+ memory_limit,
210+ memory_limit_per_task,
211+ memory_fraction,
212+ ) ?;
213+ let memory_pool =
214+ create_memory_pool ( & memory_pool_config, task_memory_manager, task_attempt_id) ;
215+
148216 // We need to keep the session context alive. Some session state like temporary
149217 // dictionaries are stored in session context. If it is dropped, the temporary
150218 // dictionaries will be dropped as well.
151- let session = prepare_datafusion_session_context (
152- batch_size as usize ,
153- use_unified_memory_manager == 1 ,
154- memory_limit as usize ,
155- memory_fraction,
156- task_memory_manager,
157- ) ?;
219+ let session = prepare_datafusion_session_context ( batch_size as usize , memory_pool) ?;
158220
159221 let plan_creation_time = start. elapsed ( ) ;
160222
161223 let exec_context = Box :: new ( ExecutionContext {
162224 id,
225+ task_attempt_id,
163226 spark_plan,
164227 root_op : None ,
165228 scans : vec ! [ ] ,
@@ -172,6 +235,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
172235 debug_native : debug_native == 1 ,
173236 explain_native : explain_native == 1 ,
174237 metrics_jstrings : HashMap :: new ( ) ,
238+ memory_pool_config,
175239 } ) ;
176240
177241 Ok ( Box :: into_raw ( exec_context) as i64 )
@@ -181,22 +245,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
181245/// Configure DataFusion session context.
182246fn prepare_datafusion_session_context (
183247 batch_size : usize ,
184- use_unified_memory_manager : bool ,
185- memory_limit : usize ,
186- memory_fraction : f64 ,
187- comet_task_memory_manager : Arc < GlobalRef > ,
248+ memory_pool : Arc < dyn MemoryPool > ,
188249) -> CometResult < SessionContext > {
189250 let mut rt_config = RuntimeEnvBuilder :: new ( ) . with_disk_manager ( DiskManagerConfig :: NewOs ) ;
190-
191- // Check if we are using unified memory manager integrated with Spark.
192- if use_unified_memory_manager {
193- // Set Comet memory pool for native
194- let memory_pool = CometMemoryPool :: new ( comet_task_memory_manager) ;
195- rt_config = rt_config. with_memory_pool ( Arc :: new ( memory_pool) ) ;
196- } else {
197- // Use the memory pool from DF
198- rt_config = rt_config. with_memory_limit ( memory_limit, memory_fraction)
199- }
251+ rt_config = rt_config. with_memory_pool ( memory_pool) ;
200252
201253 // Get Datafusion configuration from Spark Execution context
202254 // can be configured in Comet Spark JVM using Spark --conf parameters
@@ -224,6 +276,107 @@ fn prepare_datafusion_session_context(
224276 Ok ( session_ctx)
225277}
226278
279+ fn parse_memory_pool_config (
280+ use_unified_memory_manager : bool ,
281+ memory_pool_type : String ,
282+ memory_limit : i64 ,
283+ memory_limit_per_task : i64 ,
284+ memory_fraction : f64 ,
285+ ) -> CometResult < MemoryPoolConfig > {
286+ let memory_pool_config = if use_unified_memory_manager {
287+ MemoryPoolConfig :: new ( MemoryPoolType :: Unified , 0 )
288+ } else {
289+ // Use the memory pool from DF
290+ let pool_size = ( memory_limit as f64 * memory_fraction) as usize ;
291+ let pool_size_per_task = ( memory_limit_per_task as f64 * memory_fraction) as usize ;
292+ match memory_pool_type. as_str ( ) {
293+ "fair_spill_task_shared" => {
294+ MemoryPoolConfig :: new ( MemoryPoolType :: FairSpillTaskShared , pool_size_per_task)
295+ }
296+ "greedy_task_shared" => {
297+ MemoryPoolConfig :: new ( MemoryPoolType :: GreedyTaskShared , pool_size_per_task)
298+ }
299+ "fair_spill_global" => {
300+ MemoryPoolConfig :: new ( MemoryPoolType :: FairSpillGlobal , pool_size)
301+ }
302+ "greedy_global" => MemoryPoolConfig :: new ( MemoryPoolType :: GreedyGlobal , pool_size) ,
303+ "fair_spill" => MemoryPoolConfig :: new ( MemoryPoolType :: FairSpill , pool_size_per_task) ,
304+ "greedy" => MemoryPoolConfig :: new ( MemoryPoolType :: Greedy , pool_size_per_task) ,
305+ _ => {
306+ return Err ( CometError :: Config ( format ! (
307+ "Unsupported memory pool type: {}" ,
308+ memory_pool_type
309+ ) ) )
310+ }
311+ }
312+ } ;
313+ Ok ( memory_pool_config)
314+ }
315+
316+ fn create_memory_pool (
317+ memory_pool_config : & MemoryPoolConfig ,
318+ comet_task_memory_manager : Arc < GlobalRef > ,
319+ task_attempt_id : i64 ,
320+ ) -> Arc < dyn MemoryPool > {
321+ const NUM_TRACKED_CONSUMERS : usize = 10 ;
322+ match memory_pool_config. pool_type {
323+ MemoryPoolType :: Unified => {
324+ // Set Comet memory pool for native
325+ let memory_pool = CometMemoryPool :: new ( comet_task_memory_manager) ;
326+ Arc :: new ( memory_pool)
327+ }
328+ MemoryPoolType :: Greedy => Arc :: new ( TrackConsumersPool :: new (
329+ GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
330+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
331+ ) ) ,
332+ MemoryPoolType :: FairSpill => Arc :: new ( TrackConsumersPool :: new (
333+ FairSpillPool :: new ( memory_pool_config. pool_size ) ,
334+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
335+ ) ) ,
336+ MemoryPoolType :: GreedyGlobal => {
337+ static GLOBAL_MEMORY_POOL_GREEDY : OnceCell < Arc < dyn MemoryPool > > = OnceCell :: new ( ) ;
338+ let memory_pool = GLOBAL_MEMORY_POOL_GREEDY . get_or_init ( || {
339+ Arc :: new ( TrackConsumersPool :: new (
340+ GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
341+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
342+ ) )
343+ } ) ;
344+ Arc :: clone ( memory_pool)
345+ }
346+ MemoryPoolType :: FairSpillGlobal => {
347+ static GLOBAL_MEMORY_POOL_FAIR : OnceCell < Arc < dyn MemoryPool > > = OnceCell :: new ( ) ;
348+ let memory_pool = GLOBAL_MEMORY_POOL_FAIR . get_or_init ( || {
349+ Arc :: new ( TrackConsumersPool :: new (
350+ FairSpillPool :: new ( memory_pool_config. pool_size ) ,
351+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
352+ ) )
353+ } ) ;
354+ Arc :: clone ( memory_pool)
355+ }
356+ MemoryPoolType :: GreedyTaskShared | MemoryPoolType :: FairSpillTaskShared => {
357+ let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS . lock ( ) . unwrap ( ) ;
358+ let per_task_memory_pool =
359+ memory_pool_map. entry ( task_attempt_id) . or_insert_with ( || {
360+ let pool: Arc < dyn MemoryPool > =
361+ if memory_pool_config. pool_type == MemoryPoolType :: GreedyTaskShared {
362+ Arc :: new ( TrackConsumersPool :: new (
363+ GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
364+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
365+ ) )
366+ } else {
367+ Arc :: new ( TrackConsumersPool :: new (
368+ FairSpillPool :: new ( memory_pool_config. pool_size ) ,
369+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
370+ ) )
371+ } ;
372+ PerTaskMemoryPool :: new ( pool)
373+ } ) ;
374+ per_task_memory_pool. num_plans += 1 ;
375+ Arc :: clone ( & per_task_memory_pool. memory_pool )
376+ }
377+ }
378+ }
379+
227380/// Prepares arrow arrays for output.
228381fn prepare_output (
229382 env : & mut JNIEnv ,
@@ -407,6 +560,22 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
407560) {
408561 try_unwrap_or_throw ( & e, |_| unsafe {
409562 let execution_context = get_execution_context ( exec_context) ;
563+ if execution_context. memory_pool_config . pool_type == MemoryPoolType :: FairSpillTaskShared
564+ || execution_context. memory_pool_config . pool_type == MemoryPoolType :: GreedyTaskShared
565+ {
566+ // Decrement the number of native plans using the per-task shared memory pool, and
567+ // remove the memory pool if the released native plan is the last native plan using it.
568+ let task_attempt_id = execution_context. task_attempt_id ;
569+ let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS . lock ( ) . unwrap ( ) ;
570+ if let Some ( per_task_memory_pool) = memory_pool_map. get_mut ( & task_attempt_id) {
571+ per_task_memory_pool. num_plans -= 1 ;
572+ if per_task_memory_pool. num_plans == 0 {
573+ // Drop the memory pool from the per-task memory pool map if there are no
574+ // more native plans using it.
575+ memory_pool_map. remove ( & task_attempt_id) ;
576+ }
577+ }
578+ }
410579 let _: Box < ExecutionContext > = Box :: from_raw ( execution_context) ;
411580 Ok ( ( ) )
412581 } )
0 commit comments