diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 85b478bc0d6d9..8f7968a37353c 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -53,7 +53,7 @@ use arrow::{ }; use arrow_array::cast::as_list_array; use arrow_array::{ArrowNativeTypeOp, Scalar}; -use arrow_buffer::{Buffer, NullBuffer}; +use arrow_buffer::NullBuffer; /// A dynamically typed, nullable single value, (the single-valued counter-part /// to arrow's [`Array`]) @@ -1402,121 +1402,6 @@ impl ScalarValue { }}; } - fn build_struct_array( - scalars: impl IntoIterator, - ) -> Result { - let arrays = scalars - .into_iter() - .map(|s| s.to_array()) - .collect::>>()?; - - let first_struct = arrays[0].as_struct_opt(); - if first_struct.is_none() { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected ScalarValue::Struct, got {:?}", - arrays[0].clone() - ); - } - - let mut valid = BooleanBufferBuilder::new(arrays.len()); - - let first_struct = first_struct.unwrap(); - valid.append(first_struct.is_valid(0)); - - let mut column_values: Vec> = - vec![Vec::with_capacity(arrays.len()); first_struct.num_columns()]; - - for (i, v) in first_struct.columns().iter().enumerate() { - // ScalarValue::Struct contains a single element in each column. - let sv = ScalarValue::try_from_array(v, 0)?; - column_values[i].push(sv); - } - - for arr in arrays.iter().skip(1) { - if let Some(struct_array) = arr.as_struct_opt() { - valid.append(struct_array.is_valid(0)); - - for (i, v) in struct_array.columns().iter().enumerate() { - // ScalarValue::Struct contains a single element in each column. - let sv = ScalarValue::try_from_array(v, 0)?; - column_values[i].push(sv); - } - } else { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected ScalarValue::Struct, got {arr:?}" - ); - } - } - - let column_fields = first_struct.fields().to_vec(); - - let mut data = vec![]; - for (field, values) in - column_fields.into_iter().zip(column_values.into_iter()) - { - let field = field.to_owned(); - let array = ScalarValue::iter_to_array(values.into_iter())?; - data.push((field, array)); - } - - let bool_buffer = valid.finish(); - let buffer: Buffer = bool_buffer.values().into(); - Ok(Arc::new(StructArray::from((data, buffer)))) - } - - fn build_list_array( - scalars: impl IntoIterator, - ) -> Result { - let arrays = scalars - .into_iter() - .map(|s| s.to_array()) - .collect::>>()?; - - let capacity = Capacities::Array( - arrays - .iter() - .filter_map(|arr| { - if !arr.is_null(0) { - Some(arr.len()) - } else { - None - } - }) - .sum(), - ); - - // ScalarValue::List contains a single element ListArray. - let nulls = arrays - .iter() - .map(|arr| arr.is_null(0)) - .collect::>(); - let arrays_data = arrays - .iter() - .filter(|arr| !arr.is_null(0)) - .map(|arr| arr.to_data()) - .collect::>(); - - let arrays_ref = arrays_data.iter().collect::>(); - let mut mutable = - MutableArrayData::with_capacities(arrays_ref, true, capacity); - - // ScalarValue::List contains a single element ListArray. - let mut index = 0; - for is_null in nulls.into_iter() { - if is_null { - mutable.extend_nulls(1); - } else { - // mutable array contains non-null elements - mutable.extend(index, 0, 1); - index += 1; - } - } - let data = mutable.freeze(); - Ok(arrow_array::make_array(data)) - } - let array: ArrayRef = match &data_type { DataType::Decimal128(precision, scale) => { let decimal_array = @@ -1591,10 +1476,32 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } - DataType::Struct(_) => build_struct_array(scalars)?, - DataType::List(_) - | DataType::LargeList(_) - | DataType::FixedSizeList(_, _) => build_list_array(scalars)?, + DataType::FixedSizeList(_, _) => { + // arrow::compute::concat does not allow inconsistent types including the size of FixedSizeList. + // The length of nulls here we got is 1, so we need to resize the length of nulls to + // the length of non-nulls. + let mut arrays = + scalars.map(|s| s.to_array()).collect::>>()?; + let first_non_null_data_type = arrays + .iter() + .find(|sv| !sv.is_null(0)) + .map(|sv| sv.data_type().to_owned()); + if let Some(DataType::FixedSizeList(f, l)) = first_non_null_data_type { + for array in arrays.iter_mut() { + if array.is_null(0) { + *array = + Arc::new(FixedSizeListArray::new_null(f.clone(), l, 1)); + } + } + } + let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); + arrow::compute::concat(arrays.as_slice())? + } + DataType::List(_) | DataType::LargeList(_) | DataType::Struct(_) => { + let arrays = scalars.map(|s| s.to_array()).collect::>>()?; + let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); + arrow::compute::concat(arrays.as_slice())? + } DataType::Dictionary(key_type, value_type) => { // create the values array let value_scalars = scalars @@ -3529,6 +3436,44 @@ mod tests { .collect() } + #[test] + fn test_iter_to_array_fixed_size_list() { + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let f1 = Arc::new(FixedSizeListArray::new( + field.clone(), + 3, + Arc::new(Int32Array::from(vec![1, 2, 3])), + None, + )); + let f2 = Arc::new(FixedSizeListArray::new( + field.clone(), + 3, + Arc::new(Int32Array::from(vec![4, 5, 6])), + None, + )); + let f_nulls = Arc::new(FixedSizeListArray::new_null(field, 1, 1)); + + let scalars = vec![ + ScalarValue::FixedSizeList(f_nulls.clone()), + ScalarValue::FixedSizeList(f1), + ScalarValue::FixedSizeList(f2), + ScalarValue::FixedSizeList(f_nulls), + ]; + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + + let expected = FixedSizeListArray::from_iter_primitive::( + vec![ + None, + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5), Some(6)]), + None, + ], + 3, + ); + assert_eq!(array.as_ref(), &expected); + } + #[test] fn test_iter_to_array_struct() { let s1 = StructArray::from(vec![