@@ -27,6 +27,9 @@ use datafusion::{
2727 physical_plan:: { display:: DisplayableExecutionPlan , SendableRecordBatchStream } ,
2828 prelude:: { SessionConfig , SessionContext } ,
2929} ;
30+ use datafusion_execution:: memory_pool:: {
31+ FairSpillPool , GreedyMemoryPool , MemoryPool , TrackConsumersPool ,
32+ } ;
3033use futures:: poll;
3134use jni:: {
3235 errors:: Result as JNIResult ,
@@ -53,20 +56,26 @@ use crate::{
5356use datafusion_comet_proto:: spark_operator:: Operator ;
5457use datafusion_common:: ScalarValue ;
5558use futures:: stream:: StreamExt ;
59+ use jni:: sys:: JNI_FALSE ;
5660use jni:: {
5761 objects:: GlobalRef ,
5862 sys:: { jboolean, jdouble, jintArray, jobjectArray, jstring} ,
5963} ;
64+ use std:: num:: NonZeroUsize ;
65+ use std:: sync:: Mutex ;
6066use tokio:: runtime:: Runtime ;
6167
6268use crate :: execution:: operators:: ScanExec ;
6369use crate :: execution:: spark_plan:: SparkPlan ;
6470use log:: info;
71+ use once_cell:: sync:: { Lazy , OnceCell } ;
6572
6673/// Comet native execution context. Kept alive across JNI calls.
6774struct ExecutionContext {
6875 /// The id of the execution context.
6976 pub id : i64 ,
77+ /// Task attempt id
78+ pub task_attempt_id : i64 ,
7079 /// The deserialized Spark plan
7180 pub spark_plan : Operator ,
7281 /// The DataFusion root operator converted from the `spark_plan`
@@ -91,6 +100,51 @@ struct ExecutionContext {
91100 pub explain_native : bool ,
92101 /// Map of metrics name -> jstring object to cache jni_NewStringUTF calls.
93102 pub metrics_jstrings : HashMap < String , Arc < GlobalRef > > ,
103+ /// Memory pool config
104+ pub memory_pool_config : MemoryPoolConfig ,
105+ }
106+
107+ #[ derive( PartialEq , Eq ) ]
108+ enum MemoryPoolType {
109+ Unified ,
110+ Greedy ,
111+ FairSpill ,
112+ GreedyTaskShared ,
113+ FairSpillTaskShared ,
114+ GreedyGlobal ,
115+ FairSpillGlobal ,
116+ }
117+
118+ struct MemoryPoolConfig {
119+ pool_type : MemoryPoolType ,
120+ pool_size : usize ,
121+ }
122+
123+ impl MemoryPoolConfig {
124+ fn new ( pool_type : MemoryPoolType , pool_size : usize ) -> Self {
125+ Self {
126+ pool_type,
127+ pool_size,
128+ }
129+ }
130+ }
131+
132+ /// The per-task memory pools keyed by task attempt id.
133+ static TASK_SHARED_MEMORY_POOLS : Lazy < Mutex < HashMap < i64 , PerTaskMemoryPool > > > =
134+ Lazy :: new ( || Mutex :: new ( HashMap :: new ( ) ) ) ;
135+
136+ struct PerTaskMemoryPool {
137+ memory_pool : Arc < dyn MemoryPool > ,
138+ num_plans : usize ,
139+ }
140+
141+ impl PerTaskMemoryPool {
142+ fn new ( memory_pool : Arc < dyn MemoryPool > ) -> Self {
143+ Self {
144+ memory_pool,
145+ num_plans : 0 ,
146+ }
147+ }
94148}
95149
96150/// Accept serialized query plan and return the address of the native query plan.
@@ -107,8 +161,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
107161 comet_task_memory_manager_obj : JObject ,
108162 batch_size : jint ,
109163 use_unified_memory_manager : jboolean ,
164+ memory_pool_type : jstring ,
110165 memory_limit : jlong ,
166+ memory_limit_per_task : jlong ,
111167 memory_fraction : jdouble ,
168+ task_attempt_id : jlong ,
112169 debug_native : jboolean ,
113170 explain_native : jboolean ,
114171 worker_threads : jint ,
@@ -147,21 +204,27 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
147204 let task_memory_manager =
148205 Arc :: new ( jni_new_global_ref ! ( env, comet_task_memory_manager_obj) ?) ;
149206
207+ let memory_pool_type = env. get_string ( & JString :: from_raw ( memory_pool_type) ) ?. into ( ) ;
208+ let memory_pool_config = parse_memory_pool_config (
209+ use_unified_memory_manager != JNI_FALSE ,
210+ memory_pool_type,
211+ memory_limit,
212+ memory_limit_per_task,
213+ memory_fraction,
214+ ) ?;
215+ let memory_pool =
216+ create_memory_pool ( & memory_pool_config, task_memory_manager, task_attempt_id) ;
217+
150218 // We need to keep the session context alive. Some session state like temporary
151219 // dictionaries are stored in session context. If it is dropped, the temporary
152220 // dictionaries will be dropped as well.
153- let session = prepare_datafusion_session_context (
154- batch_size as usize ,
155- use_unified_memory_manager == 1 ,
156- memory_limit as usize ,
157- memory_fraction,
158- task_memory_manager,
159- ) ?;
221+ let session = prepare_datafusion_session_context ( batch_size as usize , memory_pool) ?;
160222
161223 let plan_creation_time = start. elapsed ( ) ;
162224
163225 let exec_context = Box :: new ( ExecutionContext {
164226 id,
227+ task_attempt_id,
165228 spark_plan,
166229 root_op : None ,
167230 scans : vec ! [ ] ,
@@ -174,6 +237,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
174237 debug_native : debug_native == 1 ,
175238 explain_native : explain_native == 1 ,
176239 metrics_jstrings : HashMap :: new ( ) ,
240+ memory_pool_config,
177241 } ) ;
178242
179243 Ok ( Box :: into_raw ( exec_context) as i64 )
@@ -183,22 +247,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
183247/// Configure DataFusion session context.
184248fn prepare_datafusion_session_context (
185249 batch_size : usize ,
186- use_unified_memory_manager : bool ,
187- memory_limit : usize ,
188- memory_fraction : f64 ,
189- comet_task_memory_manager : Arc < GlobalRef > ,
250+ memory_pool : Arc < dyn MemoryPool > ,
190251) -> CometResult < SessionContext > {
191252 let mut rt_config = RuntimeConfig :: new ( ) . with_disk_manager ( DiskManagerConfig :: NewOs ) ;
192-
193- // Check if we are using unified memory manager integrated with Spark.
194- if use_unified_memory_manager {
195- // Set Comet memory pool for native
196- let memory_pool = CometMemoryPool :: new ( comet_task_memory_manager) ;
197- rt_config = rt_config. with_memory_pool ( Arc :: new ( memory_pool) ) ;
198- } else {
199- // Use the memory pool from DF
200- rt_config = rt_config. with_memory_limit ( memory_limit, memory_fraction)
201- }
253+ rt_config = rt_config. with_memory_pool ( memory_pool) ;
202254
203255 // Get Datafusion configuration from Spark Execution context
204256 // can be configured in Comet Spark JVM using Spark --conf parameters
@@ -225,6 +277,107 @@ fn prepare_datafusion_session_context(
225277 Ok ( session_ctx)
226278}
227279
280+ fn parse_memory_pool_config (
281+ use_unified_memory_manager : bool ,
282+ memory_pool_type : String ,
283+ memory_limit : i64 ,
284+ memory_limit_per_task : i64 ,
285+ memory_fraction : f64 ,
286+ ) -> CometResult < MemoryPoolConfig > {
287+ let memory_pool_config = if use_unified_memory_manager {
288+ MemoryPoolConfig :: new ( MemoryPoolType :: Unified , 0 )
289+ } else {
290+ // Use the memory pool from DF
291+ let pool_size = ( memory_limit as f64 * memory_fraction) as usize ;
292+ let pool_size_per_task = ( memory_limit_per_task as f64 * memory_fraction) as usize ;
293+ match memory_pool_type. as_str ( ) {
294+ "fair_spill_task_shared" => {
295+ MemoryPoolConfig :: new ( MemoryPoolType :: FairSpillTaskShared , pool_size_per_task)
296+ }
297+ "greedy_task_shared" => {
298+ MemoryPoolConfig :: new ( MemoryPoolType :: GreedyTaskShared , pool_size_per_task)
299+ }
300+ "fair_spill_global" => {
301+ MemoryPoolConfig :: new ( MemoryPoolType :: FairSpillGlobal , pool_size)
302+ }
303+ "greedy_global" => MemoryPoolConfig :: new ( MemoryPoolType :: GreedyGlobal , pool_size) ,
304+ "fair_spill" => MemoryPoolConfig :: new ( MemoryPoolType :: FairSpill , pool_size_per_task) ,
305+ "greedy" => MemoryPoolConfig :: new ( MemoryPoolType :: Greedy , pool_size_per_task) ,
306+ _ => {
307+ return Err ( CometError :: Config ( format ! (
308+ "Unsupported memory pool type: {}" ,
309+ memory_pool_type
310+ ) ) )
311+ }
312+ }
313+ } ;
314+ Ok ( memory_pool_config)
315+ }
316+
317+ fn create_memory_pool (
318+ memory_pool_config : & MemoryPoolConfig ,
319+ comet_task_memory_manager : Arc < GlobalRef > ,
320+ task_attempt_id : i64 ,
321+ ) -> Arc < dyn MemoryPool > {
322+ const NUM_TRACKED_CONSUMERS : usize = 10 ;
323+ match memory_pool_config. pool_type {
324+ MemoryPoolType :: Unified => {
325+ // Set Comet memory pool for native
326+ let memory_pool = CometMemoryPool :: new ( comet_task_memory_manager) ;
327+ Arc :: new ( memory_pool)
328+ }
329+ MemoryPoolType :: Greedy => Arc :: new ( TrackConsumersPool :: new (
330+ GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
331+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
332+ ) ) ,
333+ MemoryPoolType :: FairSpill => Arc :: new ( TrackConsumersPool :: new (
334+ FairSpillPool :: new ( memory_pool_config. pool_size ) ,
335+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
336+ ) ) ,
337+ MemoryPoolType :: GreedyGlobal => {
338+ static GLOBAL_MEMORY_POOL_GREEDY : OnceCell < Arc < dyn MemoryPool > > = OnceCell :: new ( ) ;
339+ let memory_pool = GLOBAL_MEMORY_POOL_GREEDY . get_or_init ( || {
340+ Arc :: new ( TrackConsumersPool :: new (
341+ GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
342+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
343+ ) )
344+ } ) ;
345+ Arc :: clone ( memory_pool)
346+ }
347+ MemoryPoolType :: FairSpillGlobal => {
348+ static GLOBAL_MEMORY_POOL_FAIR : OnceCell < Arc < dyn MemoryPool > > = OnceCell :: new ( ) ;
349+ let memory_pool = GLOBAL_MEMORY_POOL_FAIR . get_or_init ( || {
350+ Arc :: new ( TrackConsumersPool :: new (
351+ FairSpillPool :: new ( memory_pool_config. pool_size ) ,
352+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
353+ ) )
354+ } ) ;
355+ Arc :: clone ( memory_pool)
356+ }
357+ MemoryPoolType :: GreedyTaskShared | MemoryPoolType :: FairSpillTaskShared => {
358+ let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS . lock ( ) . unwrap ( ) ;
359+ let per_task_memory_pool =
360+ memory_pool_map. entry ( task_attempt_id) . or_insert_with ( || {
361+ let pool: Arc < dyn MemoryPool > =
362+ if memory_pool_config. pool_type == MemoryPoolType :: GreedyTaskShared {
363+ Arc :: new ( TrackConsumersPool :: new (
364+ GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
365+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
366+ ) )
367+ } else {
368+ Arc :: new ( TrackConsumersPool :: new (
369+ FairSpillPool :: new ( memory_pool_config. pool_size ) ,
370+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
371+ ) )
372+ } ;
373+ PerTaskMemoryPool :: new ( pool)
374+ } ) ;
375+ per_task_memory_pool. num_plans += 1 ;
376+ Arc :: clone ( & per_task_memory_pool. memory_pool )
377+ }
378+ }
379+ }
380+
228381/// Prepares arrow arrays for output.
229382fn prepare_output (
230383 env : & mut JNIEnv ,
@@ -408,6 +561,20 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
408561) {
409562 try_unwrap_or_throw ( & e, |_| unsafe {
410563 let execution_context = get_execution_context ( exec_context) ;
564+ if execution_context. memory_pool_config . pool_type == MemoryPoolType :: FairSpillTaskShared {
565+ // Decrement the number of native plans using the per-task shared memory pool, and
566+ // remove the memory pool if the released native plan is the last native plan using it.
567+ let task_attempt_id = execution_context. task_attempt_id ;
568+ let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS . lock ( ) . unwrap ( ) ;
569+ if let Some ( per_task_memory_pool) = memory_pool_map. get_mut ( & task_attempt_id) {
570+ per_task_memory_pool. num_plans -= 1 ;
571+ if per_task_memory_pool. num_plans == 0 {
572+ // Drop the memory pool from the per-task memory pool map if there are no
573+ // more native plans using it.
574+ memory_pool_map. remove ( & task_attempt_id) ;
575+ }
576+ }
577+ }
411578 let _: Box < ExecutionContext > = Box :: from_raw ( execution_context) ;
412579 Ok ( ( ) )
413580 } )
0 commit comments