2222use std:: fmt:: { self , Debug } ;
2323use std:: ops:: Sub ;
2424
25+ use arrow:: datatypes:: ArrowNativeType ;
2526use hashbrown:: hash_table:: Entry :: { Occupied , Vacant } ;
2627use hashbrown:: HashTable ;
2728
@@ -254,39 +255,50 @@ impl JoinHashMapType for JoinHashMapU64 {
/// Resume position used when fetching matched indices from a `JoinHashMap` in
/// limited-size batches.
///
/// The first element is the index of the input row (into `hash_values`) the
/// lookup stopped at. The second element describes how far that row got:
///
/// * `None` - processing of that input row has not been started yet;
/// * `Some(0)` - that input row was fully processed and must be skipped;
/// * `Some(n)` with `n > 0` - continue that row's chain from chain index `n`.
pub(crate) type JoinHashMapOffset = (usize, Option<u64>);
256257
257- // Macro for traversing chained values with limit.
258- // Early returns in case of reaching output tuples limit.
259- macro_rules! chain_traverse {
260- (
261- $input_indices: ident, $match_indices: ident,
262- $hash_values: ident, $next_chain: ident,
263- $input_idx: ident, $chain_idx: ident, $remaining_output: ident, $one: ident, $zero: ident
264- ) => { {
265- // now `one` and `zero` are in scope from the outer function
266- let mut match_row_idx = $chain_idx - $one;
267- loop {
268- $match_indices. push( match_row_idx. into( ) ) ;
269- $input_indices. push( $input_idx as u32 ) ;
270- $remaining_output -= 1 ;
271-
272- let next = $next_chain[ match_row_idx. into( ) as usize ] ;
273-
274- if $remaining_output == 0 {
275- // we compare against `zero` (of type T) here too
276- let next_offset = if $input_idx == $hash_values. len( ) - 1 && next == $zero
277- {
278- None
279- } else {
280- Some ( ( $input_idx, Some ( next. into( ) ) ) )
281- } ;
282- return ( $input_indices, $match_indices, next_offset) ;
283- }
284- if next == $zero {
285- break ;
286- }
287- match_row_idx = next - $one;
258+ /// Traverses the chain of matching indices, collecting results up to the remaining limit.
259+ /// Returns `Some(offset)` if the limit was reached and there are more results to process,
260+ /// or `None` if the chain was fully traversed.
261+ #[ inline( always) ]
262+ fn traverse_chain < T > (
263+ next_chain : & [ T ] ,
264+ input_idx : usize ,
265+ start_chain_idx : T ,
266+ remaining : & mut usize ,
267+ input_indices : & mut Vec < u32 > ,
268+ match_indices : & mut Vec < u64 > ,
269+ is_last_input : bool ,
270+ ) -> Option < JoinHashMapOffset >
271+ where
272+ T : Copy + TryFrom < usize > + PartialOrd + Into < u64 > + Sub < Output = T > ,
273+ <T as TryFrom < usize > >:: Error : Debug ,
274+ T : ArrowNativeType ,
275+ {
276+ let zero = T :: usize_as ( 0 ) ;
277+ let one = T :: usize_as ( 1 ) ;
278+ let mut match_row_idx = start_chain_idx - one;
279+
280+ loop {
281+ match_indices. push ( match_row_idx. into ( ) ) ;
282+ input_indices. push ( input_idx as u32 ) ;
283+ * remaining -= 1 ;
284+
285+ let next = next_chain[ match_row_idx. into ( ) as usize ] ;
286+
287+ if * remaining == 0 {
288+ // Limit reached - return offset for next call
289+ return if is_last_input && next == zero {
290+ // Finished processing the last input row
291+ None
292+ } else {
293+ Some ( ( input_idx, Some ( next. into ( ) ) ) )
294+ } ;
288295 }
289- } } ;
296+ if next == zero {
297+ // End of chain
298+ return None ;
299+ }
300+ match_row_idx = next - one;
301+ }
290302}
291303
292304pub fn update_from_iter < ' a , T > (
@@ -380,10 +392,10 @@ pub fn get_matched_indices_with_limit_offset<T>(
380392where
381393 T : Copy + TryFrom < usize > + PartialOrd + Into < u64 > + Sub < Output = T > ,
382394 <T as TryFrom < usize > >:: Error : Debug ,
395+ T : ArrowNativeType ,
383396{
384397 let mut input_indices = Vec :: with_capacity ( limit) ;
385398 let mut match_indices = Vec :: with_capacity ( limit) ;
386- let zero = T :: try_from ( 0 ) . unwrap ( ) ;
387399 let one = T :: try_from ( 1 ) . unwrap ( ) ;
388400
389401 // Check if hashmap consists of unique values
@@ -409,47 +421,49 @@ where
409421
410422 // Calculate initial `hash_values` index before iterating
411423 let to_skip = match offset {
412- // None `initial_next_idx` indicates that `initial_idx` processing has'n been started
424+ // None `initial_next_idx` indicates that `initial_idx` processing hasn't been started
413425 ( idx, None ) => idx,
414426 // Zero `initial_next_idx` indicates that `initial_idx` has been processed during
415427 // previous iteration, and it should be skipped
416428 ( idx, Some ( 0 ) ) => idx + 1 ,
417429 // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`,
418430 // to start with the next index
419431 ( idx, Some ( next_idx) ) => {
420- let next_idx: T = T :: try_from ( next_idx as usize ) . unwrap ( ) ;
421- chain_traverse ! (
422- input_indices,
423- match_indices,
424- hash_values,
432+ let next_idx: T = T :: usize_as ( next_idx as usize ) ;
433+ let is_last = idx == hash_values. len ( ) - 1 ;
434+ if let Some ( next_offset) = traverse_chain (
425435 next_chain,
426436 idx,
427437 next_idx,
428- remaining_output,
429- one,
430- zero
431- ) ;
438+ & mut remaining_output,
439+ & mut input_indices,
440+ & mut match_indices,
441+ is_last,
442+ ) {
443+ return ( input_indices, match_indices, Some ( next_offset) ) ;
444+ }
432445 idx + 1
433446 }
434447 } ;
435448
436- let mut row_idx = to_skip;
437- for & hash in & hash_values[ to_skip..] {
449+ let hash_values_len = hash_values. len ( ) ;
450+ for ( i, & hash) in hash_values[ to_skip..] . iter ( ) . enumerate ( ) {
451+ let row_idx = to_skip + i;
438452 if let Some ( ( _, idx) ) = map. find ( hash, |( h, _) | hash == * h) {
439453 let idx: T = * idx;
440- chain_traverse ! (
441- input_indices,
442- match_indices,
443- hash_values,
454+ let is_last = row_idx == hash_values_len - 1 ;
455+ if let Some ( next_offset) = traverse_chain (
444456 next_chain,
445457 row_idx,
446458 idx,
447- remaining_output,
448- one,
449- zero
450- ) ;
459+ & mut remaining_output,
460+ & mut input_indices,
461+ & mut match_indices,
462+ is_last,
463+ ) {
464+ return ( input_indices, match_indices, Some ( next_offset) ) ;
465+ }
451466 }
452- row_idx += 1 ;
453467 }
454468 ( input_indices, match_indices, None )
455469}
0 commit comments