diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 3818e8ee56587..85edf95f1f6bb 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -32,6 +32,7 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; +use datafusion_common::utils::list_ndims; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use strum::IntoEnumIterator; @@ -498,25 +499,6 @@ impl BuiltinScalarFunction { } } - /// Returns the dimension [`DataType`] of [`DataType::List`] if - /// treated as a N-dimensional array. - /// - /// ## Examples: - /// - /// * `Int64` has dimension 1 - /// * `List(Int64)` has dimension 2 - /// * `List(List(Int64))` has dimension 3 - /// * etc. - fn return_dimension(self, input_expr_type: &DataType) -> u64 { - let mut result: u64 = 1; - let mut current_data_type = input_expr_type; - while let DataType::List(field) = current_data_type { - current_data_type = field.data_type(); - result += 1; - } - result - } - /// Returns the output [`DataType`] of this function /// /// This method should be invoked only after `input_expr_types` have been validated @@ -552,25 +534,30 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { - let mut expr_type = Null; + let mut expr_type: Option = None; let mut max_dims = 0; for input_expr_type in input_expr_types { match input_expr_type { - List(field) => { - if !field.data_type().equals_datatype(&Null) { - let dims = self.return_dimension(input_expr_type); - expr_type = match max_dims.cmp(&dims) { - Ordering::Greater => expr_type, + List(_) => { + let dims = list_ndims(input_expr_type); + if let Some(data_type) = expr_type { + let new_type = match max_dims.cmp(&dims) { + Ordering::Greater => data_type, Ordering::Equal => { - 
get_wider_type(&expr_type, input_expr_type)? + get_wider_type(&data_type, input_expr_type)? } Ordering::Less => { max_dims = dims; input_expr_type.clone() } }; + expr_type = Some(new_type) + } else { + expr_type = Some(input_expr_type.clone()); + max_dims = dims; } } + DataType::Null => {} _ => { return plan_err!( "The {self} function can only accept list as the args." @@ -579,7 +566,11 @@ impl BuiltinScalarFunction { } } - Ok(expr_type) + if let Some(expr_type) = expr_type { + Ok(expr_type) + } else { + Ok(DataType::Null) + } } BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny @@ -929,9 +920,10 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), - BuiltinScalarFunction::ArrayConcat => { - Signature::variadic_any(self.volatility()) - } + BuiltinScalarFunction::ArrayConcat => Signature { + type_signature: ArrayConcat, + volatility: self.volatility(), + }, BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 3f07c300e1962..353f24e111258 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -122,6 +122,9 @@ pub enum TypeSignature { /// List dimension of the List/LargeList is equivalent to the number of List. /// List dimension of the non-list is 0. ArrayAndElement, + /// Specialized Signature for ArrayConcat + /// Accept arbitrary arguments but they SHOULD be List/LargeList or Null, and the list dimension MAY NOT be the same. 
+ ArrayConcat, } impl TypeSignature { @@ -155,6 +158,9 @@ impl TypeSignature { TypeSignature::ArrayAndElement => { vec!["ArrayAndElement(List, T)".to_string()] } + TypeSignature::ArrayConcat => { + vec!["ArrayConcat(List / NULL, .., List / NULL)".to_string()] + } } } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index dd9449198796a..082cdc9a4d47d 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -29,7 +29,8 @@ use arrow::datatypes::{ }; use datafusion_common::{ - exec_datafusion_err, plan_datafusion_err, plan_err, DataFusionError, Result, + exec_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, + Result, }; /// The type signature of an instantiation of binary operator expression such as @@ -300,6 +301,21 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Result { + data_types.iter().skip(1).try_fold( + data_types.first().unwrap().clone(), + |current_type, other_type| { + let coerced_type = comparison_coercion(&current_type, other_type); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {current_type:?} to {other_type:?} failed.") + } + }, + ) +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation /// where one is numeric and one is `Utf8`/`LargeUtf8`.
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index f95a30e025b49..c331f17e43bfd 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -24,7 +24,7 @@ use arrow::{ use datafusion_common::utils::list_ndims; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; -use super::binary::comparison_coercion; +use super::binary::{comparison_coercion, comparison_coercion_for_iter}; /// Performs type coercion for function arguments. /// @@ -89,18 +89,7 @@ fn get_valid_types( .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::VariadicEqual => { - let new_type = current_types.iter().skip(1).try_fold( - current_types.first().unwrap().clone(), - |acc, x| { - let coerced_type = comparison_coercion(&acc, x); - if let Some(coerced_type) = coerced_type { - Ok(coerced_type) - } else { - internal_err!("Coercion from {acc:?} to {x:?} failed.") - } - }, - ); - + let new_type = comparison_coercion_for_iter(current_types); match new_type { Ok(new_type) => vec![vec![new_type; current_types.len()]], Err(e) => return Err(e), @@ -149,6 +138,33 @@ fn get_valid_types( return Ok(vec![vec![]]); } } + TypeSignature::ArrayConcat => { + let base_types = current_types + .iter() + .map(datafusion_common::utils::base_type) + .collect::>(); + + let new_base_type = comparison_coercion_for_iter(base_types.as_slice()); + match new_base_type { + Ok(new_base_type) => { + let array_types = current_types + .iter() + .map(|t| { + if t.eq(&DataType::Null) { + t.to_owned() + } else { + datafusion_common::utils::coerced_type_with_base_type_only( + t, + &new_base_type, + ) + } + }) + .collect::>(); + return Ok(vec![array_types]); + } + Err(e) => return Err(e), + } + } TypeSignature::Any(number) => { if current_types.len() != *number { return 
plan_err!( diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index d396581083371..c4d6e7af389e4 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -33,7 +33,7 @@ use datafusion_common::cast::{ as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array, as_list_array, as_null_array, as_string_array, }; -use datafusion_common::utils::{array_into_list_array, list_ndims}; +use datafusion_common::utils::array_into_list_array; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, }; @@ -1084,15 +1084,22 @@ fn concat_internal(args: &[ArrayRef]) -> Result { pub fn array_concat(args: &[ArrayRef]) -> Result { let mut new_args = vec![]; for arg in args { - let ndim = list_ndims(arg.data_type()); - let base_type = datafusion_common::utils::base_type(arg.data_type()); - if ndim == 0 { - return not_impl_err!("Array is not type '{base_type:?}'."); - } else if !base_type.eq(&DataType::Null) { + let data_type = arg.data_type(); + if let DataType::List(_) = data_type { new_args.push(arg.clone()); + } else if data_type.eq(&DataType::Null) { + // Null type is valid. + continue; + } else { + return internal_err!("Expect Array type, found {:?}", data_type); } } + // All the arguments are null, return null + if new_args.is_empty() { + return Ok(new_null_array(&DataType::Null, 0)); + } + concat_internal(new_args.as_slice()) } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b38f73ecb8dbd..bda07a2830add 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1609,22 +1609,25 @@ select array_concat(make_array(), make_array(2, 3)); [2, 3] # array_concat scalar function #7 (with empty arrays) +## DuckDB and ClickHouse both return '[[1, 2], [3, 4], []]' query ? 
select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array())); ---- -[[1, 2], [3, 4]] +[[1, 2], [3, 4], []] # array_concat scalar function #8 (with empty arrays) +## DuckDB return error, ClickHouse return '[[1, 2], [3, 4], [], [], [], [5, 6], [7, 8]]' query ? -select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array()), make_array(make_array(), make_array()), make_array(make_array(5, 6), make_array(7, 8))); +select array_concat([[1,2], [3,4]], [[]], [[],[]], [[5,6], [7,8]]); ---- -[[1, 2], [3, 4], [5, 6], [7, 8]] +[[1, 2], [3, 4], [], [], [], [5, 6], [7, 8]] # array_concat scalar function #9 (with empty arrays) +## DuckDB and ClickHouse both return '[[], [1, 2], [3, 4]]' query ? select array_concat(make_array(make_array()), make_array(make_array(1, 2), make_array(3, 4))); ---- -[[1, 2], [3, 4]] +[[], [1, 2], [3, 4]] # array_cat scalar function #10 (function alias `array_concat`) query ?? @@ -1818,6 +1821,86 @@ select array_concat(make_array(column3), column1, column2) from arrays_values_v2 [, 11, 12] [] +# array concat with nulls +query ? +select array_concat([1,2,3], null); +---- +[1, 2, 3] + +query ? +select array_concat(null, [1,2,3]); +---- +[1, 2, 3] + +query ? +select array_concat([1, null, 2], [3, 4, null]); +---- +[1, , 2, 3, 4, ] + +query ? +select array_concat([1,2], [null,3]); +---- +[1, 2, , 3] + +query ? +select array_concat([[1,2]], [[null,3]]); +---- +[[1, 2], [, 3]] + +query ? +select array_concat([1, null], [[null, 2]]); +---- +[[1, ], [, 2]] + +query ? +select array_concat([1, null], [null]); +---- +[1, , ] + +query ? +select array_concat(null, null); +---- +NULL + +query ? +select array_concat([], null); +---- +[] + +query ? +select array_concat([], []); +---- +[] + +query ? +select array_concat([null], [null]); +---- +[, ] + +# 3D null + 1D + 2D empty +query ? +select array_concat([[[null]]], [1, 2], [[]]); +---- +[[[]], [[1, 2]], [[]]] + +# 1D + 2D + 3D empty +query ? 
+select array_concat([], [[]], [[[]]]); +---- +[[[]], [[]], [[]]] + +# 1D + 2D + 3D null +query ? +select array_concat([null], [[null]], [[[null]]]); +---- +[[[]], [[]], [[]]] + +# 0D + 1D + 2D + 3D null +query ? +select array_concat(null, [null], [[null]], [[[null]]]); +---- +[[[]], [[]], [[]]] + ## array_position (aliases: `list_position`, `array_indexof`, `list_indexof`) # array_position scalar function #1