diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 763a4e6539fd..d776d07775fd 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -401,25 +401,35 @@ fn get_valid_types( let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList); let mut list_sizes = Vec::with_capacity(arguments.len()); let mut element_types = Vec::with_capacity(arguments.len()); + let mut nested_item_nullability = Vec::with_capacity(arguments.len()); for (argument, current_type) in arguments.iter().zip(current_types.iter()) { match argument { - ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (), + ArrayFunctionArgument::Index | ArrayFunctionArgument::String => { + nested_item_nullability.push(None); + } ArrayFunctionArgument::Element => { - element_types.push(current_type.clone()) + element_types.push(current_type.clone()); + nested_item_nullability.push(None); } ArrayFunctionArgument::Array => match current_type { - DataType::Null => element_types.push(DataType::Null), + DataType::Null => { + element_types.push(DataType::Null); + nested_item_nullability.push(None); + } DataType::List(field) => { element_types.push(field.data_type().clone()); + nested_item_nullability.push(Some(field.is_nullable())); fixed_size = false; } DataType::LargeList(field) => { element_types.push(field.data_type().clone()); + nested_item_nullability.push(Some(field.is_nullable())); large_list = true; fixed_size = false; } DataType::FixedSizeList(field, size) => { element_types.push(field.data_type().clone()); + nested_item_nullability.push(Some(field.is_nullable())); list_sizes.push(*size) } arg_type => { @@ -429,33 +439,49 @@ fn get_valid_types( } } + debug_assert_eq!(nested_item_nullability.len(), arguments.len()); + let Some(element_type) = type_union_resolution(&element_types) else { return Ok(vec![vec![]]); }; if !fixed_size { list_sizes.clear() - } + }; let mut list_sizes = list_sizes.into_iter(); - let valid_types = arguments.iter().zip(current_types.iter()).map( - |(argument_type, current_type)| match argument_type { - ArrayFunctionArgument::Index => DataType::Int64, - ArrayFunctionArgument::String => DataType::Utf8, - ArrayFunctionArgument::Element => element_type.clone(), - ArrayFunctionArgument::Array => { - if current_type.is_null() { - DataType::Null - } else if large_list { - DataType::new_large_list(element_type.clone(), true) - } else if let Some(size) = list_sizes.next() { - DataType::new_fixed_size_list(element_type.clone(), size, true) - } else { - DataType::new_list(element_type.clone(), true) + let valid_types = arguments + .iter() + .zip(current_types.iter()) + .zip(nested_item_nullability) + .map(|((argument_type, current_type), is_nested_item_nullable)| { + match argument_type { + ArrayFunctionArgument::Index => DataType::Int64, + ArrayFunctionArgument::String => DataType::Utf8, + ArrayFunctionArgument::Element => element_type.clone(), + ArrayFunctionArgument::Array => { + if current_type.is_null() { + DataType::Null + } else if large_list { + DataType::new_large_list( + element_type.clone(), + is_nested_item_nullable.unwrap_or(true), + ) + } else if let Some(size) = list_sizes.next() { + DataType::new_fixed_size_list( + element_type.clone(), + size, + is_nested_item_nullable.unwrap_or(true), + ) + } else { + DataType::new_list( + element_type.clone(), + is_nested_item_nullable.unwrap_or(true), + ) + } } } - }, - ); + }); Ok(vec![valid_types.collect()]) } @@ -1343,6 +1369,18 @@ mod tests { vec![vec![]] ); + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, false), + DataType::new_list(DataType::Int32, false), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int64, false), + DataType::new_list(DataType::Int64, false), + ]] + ); + Ok(()) } }