Skip to content

Commit 8e8ae88

Browse files
Weijun-H and appletreeisyellow
authored and committed
Minor: convert macro list-slice and slice to function (apache#8424)
* remove macro list-slice * fix cast dyn Array * remove macro slice
1 parent 4679f60 commit 8e8ae88

1 file changed

Lines changed: 89 additions & 92 deletions

File tree

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 89 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! Array expressions
1919
2020
use std::any::type_name;
21+
use std::cmp::Ordering;
2122
use std::collections::HashSet;
2223
use std::sync::Arc;
2324

@@ -377,111 +378,107 @@ fn return_empty(return_null: bool, data_type: DataType) -> Arc<dyn Array> {
377378
}
378379
}
379380

380-
macro_rules! list_slice {
381-
($ARRAY:expr, $I:expr, $J:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{
382-
let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
383-
if $I == 0 && $J == 0 || $ARRAY.is_empty() {
384-
return return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone());
385-
}
381+
fn list_slice<T: Array + 'static>(
382+
array: &dyn Array,
383+
i: i64,
384+
j: i64,
385+
return_element: bool,
386+
) -> ArrayRef {
387+
let array = array.as_any().downcast_ref::<T>().unwrap();
386388

387-
let i = if $I < 0 {
388-
if $I.abs() as usize > array.len() {
389-
return return_empty(true, $ARRAY.data_type().clone());
390-
}
389+
let array_type = array.data_type().clone();
391390

392-
(array.len() as i64 + $I + 1) as usize
393-
} else {
394-
if $I == 0 {
395-
1
396-
} else {
397-
$I as usize
398-
}
399-
};
400-
let j = if $J < 0 {
401-
if $J.abs() as usize > array.len() {
402-
return return_empty(true, $ARRAY.data_type().clone());
391+
if i == 0 && j == 0 || array.is_empty() {
392+
return return_empty(return_element, array_type);
393+
}
394+
395+
let i = match i.cmp(&0) {
396+
Ordering::Less => {
397+
if i.unsigned_abs() > array.len() as u64 {
398+
return return_empty(true, array_type);
403399
}
404400

405-
if $RETURN_ELEMENT {
406-
(array.len() as i64 + $J + 1) as usize
407-
} else {
408-
(array.len() as i64 + $J) as usize
401+
(array.len() as i64 + i + 1) as usize
402+
}
403+
Ordering::Equal => 1,
404+
Ordering::Greater => i as usize,
405+
};
406+
407+
let j = match j.cmp(&0) {
408+
Ordering::Less => {
409+
if j.unsigned_abs() as usize > array.len() {
410+
return return_empty(true, array_type);
409411
}
410-
} else {
411-
if $J == 0 {
412-
1
412+
if return_element {
413+
(array.len() as i64 + j + 1) as usize
413414
} else {
414-
if $J as usize > array.len() {
415-
array.len()
416-
} else {
417-
$J as usize
418-
}
415+
(array.len() as i64 + j) as usize
419416
}
420-
};
421-
422-
if i > j || i as usize > $ARRAY.len() {
423-
return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone())
424-
} else {
425-
Arc::new(array.slice((i - 1), (j + 1 - i)))
426417
}
427-
}};
418+
Ordering::Equal => 1,
419+
Ordering::Greater => j.min(array.len() as i64) as usize,
420+
};
421+
422+
if i > j || i > array.len() {
423+
return_empty(return_element, array_type)
424+
} else {
425+
Arc::new(array.slice(i - 1, j + 1 - i))
426+
}
428427
}
429428

430-
macro_rules! slice {
431-
($ARRAY:expr, $KEY:expr, $EXTRA_KEY:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{
432-
let sliced_array: Vec<Arc<dyn Array>> = $ARRAY
429+
fn slice<T: Array + 'static>(
430+
array: &ListArray,
431+
key: &Int64Array,
432+
extra_key: &Int64Array,
433+
return_element: bool,
434+
) -> Result<Arc<dyn Array>> {
435+
let sliced_array: Vec<Arc<dyn Array>> = array
436+
.iter()
437+
.zip(key.iter())
438+
.zip(extra_key.iter())
439+
.map(|((arr, i), j)| match (arr, i, j) {
440+
(Some(arr), Some(i), Some(j)) => list_slice::<T>(&arr, i, j, return_element),
441+
(Some(arr), None, Some(j)) => list_slice::<T>(&arr, 1i64, j, return_element),
442+
(Some(arr), Some(i), None) => {
443+
list_slice::<T>(&arr, i, arr.len() as i64, return_element)
444+
}
445+
(Some(arr), None, None) if !return_element => arr.clone(),
446+
_ => return_empty(return_element, array.value_type()),
447+
})
448+
.collect();
449+
450+
// concat requires input of at least one array
451+
if sliced_array.is_empty() {
452+
Ok(return_empty(return_element, array.value_type()))
453+
} else {
454+
let vec = sliced_array
433455
.iter()
434-
.zip($KEY.iter())
435-
.zip($EXTRA_KEY.iter())
436-
.map(|((arr, i), j)| match (arr, i, j) {
437-
(Some(arr), Some(i), Some(j)) => {
438-
list_slice!(arr, i, j, $RETURN_ELEMENT, $ARRAY_TYPE)
439-
}
440-
(Some(arr), None, Some(j)) => {
441-
list_slice!(arr, 1i64, j, $RETURN_ELEMENT, $ARRAY_TYPE)
442-
}
443-
(Some(arr), Some(i), None) => {
444-
list_slice!(arr, i, arr.len() as i64, $RETURN_ELEMENT, $ARRAY_TYPE)
445-
}
446-
(Some(arr), None, None) if !$RETURN_ELEMENT => arr,
447-
_ => return_empty($RETURN_ELEMENT, $ARRAY.value_type().clone()),
448-
})
449-
.collect();
456+
.map(|a| a.as_ref())
457+
.collect::<Vec<&dyn Array>>();
458+
let mut i: i32 = 0;
459+
let mut offsets = vec![i];
460+
offsets.extend(
461+
vec.iter()
462+
.map(|a| {
463+
i += a.len() as i32;
464+
i
465+
})
466+
.collect::<Vec<_>>(),
467+
);
468+
let values = compute::concat(vec.as_slice()).unwrap();
450469

451-
// concat requires input of at least one array
452-
if sliced_array.is_empty() {
453-
Ok(return_empty($RETURN_ELEMENT, $ARRAY.value_type()))
470+
if return_element {
471+
Ok(values)
454472
} else {
455-
let vec = sliced_array
456-
.iter()
457-
.map(|a| a.as_ref())
458-
.collect::<Vec<&dyn Array>>();
459-
let mut i: i32 = 0;
460-
let mut offsets = vec![i];
461-
offsets.extend(
462-
vec.iter()
463-
.map(|a| {
464-
i += a.len() as i32;
465-
i
466-
})
467-
.collect::<Vec<_>>(),
468-
);
469-
let values = compute::concat(vec.as_slice()).unwrap();
470-
471-
if $RETURN_ELEMENT {
472-
Ok(values)
473-
} else {
474-
let field =
475-
Arc::new(Field::new("item", $ARRAY.value_type().clone(), true));
476-
Ok(Arc::new(ListArray::try_new(
477-
field,
478-
OffsetBuffer::new(offsets.into()),
479-
values,
480-
None,
481-
)?))
482-
}
473+
let field = Arc::new(Field::new("item", array.value_type(), true));
474+
Ok(Arc::new(ListArray::try_new(
475+
field,
476+
OffsetBuffer::new(offsets.into()),
477+
values,
478+
None,
479+
)?))
483480
}
484-
}};
481+
}
485482
}
486483

487484
fn define_array_slice(
@@ -492,7 +489,7 @@ fn define_array_slice(
492489
) -> Result<ArrayRef> {
493490
macro_rules! array_function {
494491
($ARRAY_TYPE:ident) => {
495-
slice!(list_array, key, extra_key, return_element, $ARRAY_TYPE)
492+
slice::<$ARRAY_TYPE>(list_array, key, extra_key, return_element)
496493
};
497494
}
498495
call_array_function!(list_array.value_type(), true)

0 commit comments

Comments
 (0)