diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 36b00b65e285..7b93fd5b5663 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use crate::arrow_datafusion_err; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, - as_fixed_size_binary_array, as_fixed_size_list_array, as_struct_array, + as_fixed_size_binary_array, as_fixed_size_list_array, }; use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; @@ -38,7 +38,7 @@ use crate::utils::{ array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array, }; use arrow::compute::kernels::numeric::*; -use arrow::util::display::{ArrayFormatter, FormatOptions}; +use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, @@ -52,6 +52,8 @@ use arrow::{ }, }; use arrow_array::cast::as_list_array; +use arrow_array::{ArrowNativeTypeOp, Scalar}; +use arrow_buffer::{Buffer, NullBuffer}; /// A dynamically typed, nullable single value, (the single-valued counter-part /// to arrow's [`Array`]) @@ -152,6 +154,8 @@ pub enum ScalarValue { List(Arc), /// The array must be a LargeListArray with length 1. LargeList(Arc), + /// Represents a single element of a [`StructArray`] as an [`ArrayRef`] + Struct(Arc), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -189,8 +193,6 @@ pub enum ScalarValue { DurationMicrosecond(Option), /// Duration in nanoseconds DurationNanosecond(Option), - /// struct of nested ScalarValue - Struct(Option>, Fields), /// Dictionary type: index type and value Dictionary(Box, Box), } @@ -255,6 +257,8 @@ impl PartialEq for ScalarValue { (List(_), _) => false, (LargeList(v1), LargeList(v2)) => v1.eq(v2), (LargeList(_), _) => false, + (Struct(v1), Struct(v2)) => v1.eq(v2), + (Struct(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -289,8 +293,6 @@ impl PartialEq for ScalarValue { (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), (IntervalMonthDayNano(_), _) => false, - (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), - (Struct(_, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, (Null, Null) => true, @@ -372,6 +374,10 @@ impl PartialOrd for ScalarValue { partial_cmp_list(arr1.as_ref(), arr2.as_ref()) } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, + (Struct(struct_arr1), Struct(struct_arr2)) => { + partial_cmp_struct(struct_arr1, struct_arr2) + } + (Struct(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -412,14 +418,6 @@ impl PartialOrd for ScalarValue { (DurationMicrosecond(_), _) => None, (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2), (DurationNanosecond(_), _) => None, - (Struct(v1, t1), Struct(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) - } else { - None - } - } - (Struct(_, _), _) => None, (Dictionary(k1, v1), Dictionary(k2, v2)) => { // Don't compare if the key types don't match (it is effectively a different datatype) if k1 == k2 { @@ -473,6 +471,34 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { Some(Ordering::Equal) } +fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option { + if s1.len() != s2.len() { + return None; + } + + if s1.data_type() != s2.data_type() { + return None; + } + + for col_index in 0..s1.num_columns() { + let arr1 = s1.column(col_index); + let arr2 = s2.column(col_index); + + let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + } + Some(Ordering::Equal) +} + impl Eq for ScalarValue {} //Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper @@ -527,13 +553,16 @@ impl std::hash::Hash for ScalarValue { FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), List(arr) => { - hash_list(arr.to_owned() as ArrayRef, state); + hash_nested_array(arr.to_owned() as ArrayRef, state); } LargeList(arr) => { - hash_list(arr.to_owned() as ArrayRef, state); + hash_nested_array(arr.to_owned() as ArrayRef, state); } FixedSizeList(arr) => { - hash_list(arr.to_owned() as ArrayRef, state); + hash_nested_array(arr.to_owned() as ArrayRef, state); + } + Struct(arr) => { + hash_nested_array(arr.to_owned() as ArrayRef, state); } Date32(v) => v.hash(state), Date64(v) => v.hash(state), @@ -552,10 +581,6 @@ impl std::hash::Hash for ScalarValue { IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), IntervalMonthDayNano(v) => v.hash(state), - Struct(v, t) => { - v.hash(state); - t.hash(state); - } Dictionary(k, v) => { k.hash(state); v.hash(state); @@ -566,7 +591,7 @@ impl std::hash::Hash for ScalarValue { } } -fn hash_list(arr: ArrayRef, state: &mut H) { +fn hash_nested_array(arr: ArrayRef, state: &mut H) { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -962,6 +987,7 @@ impl ScalarValue { ScalarValue::List(arr) => arr.data_type().to_owned(), ScalarValue::LargeList(arr) => arr.data_type().to_owned(), ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), + ScalarValue::Struct(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -985,7 +1011,6 @@ impl ScalarValue { ScalarValue::DurationNanosecond(_) => { DataType::Duration(TimeUnit::Nanosecond) } - ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()), ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } @@ -1167,6 +1192,7 @@ impl ScalarValue { ScalarValue::List(arr) => arr.len() == arr.null_count(), ScalarValue::LargeList(arr) => arr.len() == arr.null_count(), ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), + ScalarValue::Struct(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1184,7 +1210,6 @@ impl ScalarValue { ScalarValue::DurationMillisecond(v) => v.is_none(), ScalarValue::DurationMicrosecond(v) => v.is_none(), ScalarValue::DurationNanosecond(v) => v.is_none(), - ScalarValue::Struct(v, _) => v.is_none(), ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -1377,6 +1402,70 @@ impl ScalarValue { }}; } + fn build_struct_array( + scalars: impl IntoIterator, + ) -> Result { + let arrays = scalars + .into_iter() + .map(|s| s.to_array()) + .collect::>>()?; + + let first_struct = arrays[0].as_struct_opt(); + if first_struct.is_none() { + return _internal_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected ScalarValue::Struct, got {:?}", + arrays[0].clone() + ); + } + + let mut valid = BooleanBufferBuilder::new(arrays.len()); + + let first_struct = first_struct.unwrap(); + valid.append(first_struct.is_valid(0)); + + let mut column_values: Vec> = + vec![Vec::with_capacity(arrays.len()); first_struct.num_columns()]; + + for (i, v) in first_struct.columns().iter().enumerate() { + // ScalarValue::Struct contains a single element in each column. + let sv = ScalarValue::try_from_array(v, 0)?; + column_values[i].push(sv); + } + + for arr in arrays.iter().skip(1) { + if let Some(struct_array) = arr.as_struct_opt() { + valid.append(struct_array.is_valid(0)); + + for (i, v) in struct_array.columns().iter().enumerate() { + // ScalarValue::Struct contains a single element in each column. + let sv = ScalarValue::try_from_array(v, 0)?; + column_values[i].push(sv); + } + } else { + return _internal_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected ScalarValue::Struct, got {arr:?}" + ); + } + } + + let column_fields = first_struct.fields().to_vec(); + + let mut data = vec![]; + for (field, values) in + column_fields.into_iter().zip(column_values.into_iter()) + { + let field = field.to_owned(); + let array = ScalarValue::iter_to_array(values.into_iter())?; + data.push((field, array)); + } + + let bool_buffer = valid.finish(); + let buffer: Buffer = bool_buffer.values().into(); + Ok(Arc::new(StructArray::from((data, buffer)))) + } + fn build_list_array( scalars: impl IntoIterator, ) -> Result { @@ -1483,56 +1572,10 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } + DataType::Struct(_) => build_struct_array(scalars)?, DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => build_list_array(scalars)?, - DataType::Struct(fields) => { - // Initialize a Vector to store the ScalarValues for each column - let mut columns: Vec> = - (0..fields.len()).map(|_| Vec::new()).collect(); - - // null mask - let mut null_mask_builder = BooleanBuilder::new(); - - // Iterate over scalars to populate the column scalars for each row - for scalar in scalars { - if let ScalarValue::Struct(values, fields) = scalar { - match values { - Some(values) => { - // Push value for each field - for (column, value) in columns.iter_mut().zip(values) { - column.push(value.clone()); - } - null_mask_builder.append_value(false); - } - None => { - // Push NULL of the appropriate type for each field - for (column, field) in - columns.iter_mut().zip(fields.as_ref()) - { - column - .push(ScalarValue::try_from(field.data_type())?); - } - null_mask_builder.append_value(true); - } - }; - } else { - return _internal_err!("Expected Struct but found: {scalar}"); - }; - } - - // Call iter_to_array recursively to convert the scalars for each column into Arrow arrays - let field_values = fields - .iter() - .zip(columns) - .map(|(field, column)| { - Ok((field.clone(), Self::iter_to_array(column)?)) - }) - .collect::>>()?; - - let array = StructArray::from(field_values); - arrow::compute::nullif(&array, &null_mask_builder.finish())? - } DataType::Dictionary(key_type, value_type) => { // create the values array let value_scalars = scalars @@ -1941,6 +1984,9 @@ impl ScalarValue { ScalarValue::FixedSizeList(arr) => { Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } + ScalarValue::Struct(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) } @@ -2032,23 +2078,6 @@ impl ScalarValue { e, size ), - ScalarValue::Struct(values, fields) => match values { - Some(values) => { - let field_values = fields - .iter() - .zip(values.iter()) - .map(|(field, value)| { - Ok((field.clone(), value.to_array_of_size(size)?)) - }) - .collect::>>()?; - - Arc::new(StructArray::from(field_values)) - } - None => { - let dt = self.data_type(); - new_null_array(&dt, size) - } - }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { @@ -2205,7 +2234,7 @@ impl ScalarValue { typed_cast!(array, index, LargeStringArray, LargeUtf8)? } DataType::List(_) => { - let list_array = as_list_array(array); + let list_array = array.as_list::(); let nested_array = list_array.value(index); // Produces a single element `ListArray` with the value at `index`. let arr = Arc::new(array_into_list_array(nested_array)); @@ -2296,15 +2325,9 @@ impl ScalarValue { Self::Dictionary(key_type.clone(), Box::new(value)) } - DataType::Struct(fields) => { - let array = as_struct_array(array)?; - let mut field_values: Vec = Vec::new(); - for col_index in 0..array.num_columns() { - let col_array = array.column(col_index); - let col_scalar = ScalarValue::try_from_array(col_array, index)?; - field_values.push(col_scalar); - } - Self::Struct(Some(field_values), fields.clone()) + DataType::Struct(_) => { + let a = array.slice(index, 1); + Self::Struct(Arc::new(a.as_struct().to_owned())) } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; @@ -2515,6 +2538,9 @@ impl ScalarValue { ScalarValue::FixedSizeList(arr) => { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } + ScalarValue::Struct(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) + } ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? } @@ -2566,9 +2592,6 @@ impl ScalarValue { ScalarValue::DurationNanosecond(val) => { eq_array_primitive!(array, index, DurationNanosecondArray, val)? } - ScalarValue::Struct(_, _) => { - return _not_impl_err!("Struct is not supported yet") - } ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { DataType::Int8 => get_dict_value::(array, index)?, @@ -2645,20 +2668,7 @@ impl ScalarValue { ScalarValue::List(arr) => arr.get_array_memory_size(), ScalarValue::LargeList(arr) => arr.get_array_memory_size(), ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), - ScalarValue::Struct(vals, fields) => { - vals.as_ref() - .map(|vals| { - vals.iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) - .sum::() - + (std::mem::size_of::() * vals.capacity()) - }) - .unwrap_or_default() - // `fields` is boxed, so it is NOT already included in `self` - + std::mem::size_of_val(fields) - + (std::mem::size_of::() * fields.len()) - + fields.iter().map(|field| field.size() - std::mem::size_of_val(field)).sum::() - } + ScalarValue::Struct(arr) => arr.get_array_memory_size(), ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() @@ -2744,6 +2754,26 @@ impl From> for ScalarValue { } } +/// Wrapper to create ScalarValue::Struct for convenience +impl From> for ScalarValue { + fn from(value: Vec<(&str, ScalarValue)>) -> Self { + let (fields, scalars): (SchemaBuilder, Vec<_>) = value + .into_iter() + .map(|(name, scalar)| (Field::new(name, scalar.data_type(), false), scalar)) + .unzip(); + + let arrays = scalars + .into_iter() + .map(|scalar| scalar.to_array().unwrap()) + .collect::>(); + + let fields = fields.finish().fields; + let struct_array = StructArray::try_new(fields, arrays, None).unwrap(); + + Self::Struct(Arc::new(struct_array)) + } +} + impl FromStr for ScalarValue { type Err = Infallible; @@ -2758,14 +2788,24 @@ impl From for ScalarValue { } } -impl From> for ScalarValue { - fn from(value: Vec<(&str, ScalarValue)>) -> Self { - let (fields, scalars): (SchemaBuilder, Vec<_>) = value - .into_iter() - .map(|(name, scalar)| (Field::new(name, scalar.data_type(), false), scalar)) - .unzip(); +// TODO: Remove this after changing to Scalar +// Wrapper for ScalarValue::Struct that checks the length of the arrays, without nulls +impl From<(Fields, Vec)> for ScalarValue { + fn from((fields, arrays): (Fields, Vec)) -> Self { + Self::from((fields, arrays, None)) + } +} - Self::Struct(Some(scalars), fields.finish().fields) +// TODO: Remove this after changing to Scalar +// Wrapper for ScalarValue::Struct that checks the length of the arrays +impl From<(Fields, Vec, Option)> for ScalarValue { + fn from( + (fields, arrays, nulls): (Fields, Vec, Option), + ) -> Self { + for arr in arrays.iter() { + assert_eq!(arr.len(), 1); + } + Self::Struct(Arc::new(StructArray::new(fields, arrays, nulls))) } } @@ -2941,7 +2981,6 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { ScalarValue::IntervalMonthDayNano(None) } - DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), DataType::Duration(TimeUnit::Millisecond) => { ScalarValue::DurationMillisecond(None) @@ -2952,7 +2991,6 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Duration(TimeUnit::Nanosecond) => { ScalarValue::DurationNanosecond(None) } - DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( index_type.clone(), Box::new(value_type.as_ref().try_into()?), @@ -2998,7 +3036,12 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), - DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), + DataType::Struct(fields) => ScalarValue::Struct( + new_null_array(&DataType::Struct(fields.to_owned()), 1) + .as_struct() + .to_owned() + .into(), + ), DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( @@ -3078,18 +3121,38 @@ impl fmt::Display for ScalarValue { ScalarValue::DurationMillisecond(e) => format_option!(f, e)?, ScalarValue::DurationMicrosecond(e) => format_option!(f, e)?, ScalarValue::DurationNanosecond(e) => format_option!(f, e)?, - ScalarValue::Struct(e, fields) => match e { - Some(l) => write!( + ScalarValue::Struct(struct_arr) => { + // ScalarValue Struct should always have a single element + assert_eq!(struct_arr.len(), 1); + + let columns = struct_arr.columns(); + let fields = struct_arr.fields(); + let nulls = struct_arr.nulls(); + + write!( f, "{{{}}}", - l.iter() + columns + .iter() .zip(fields.iter()) - .map(|(value, field)| format!("{}:{}", field.name(), value)) + .enumerate() + .map(|(index, (column, field))| { + if nulls.is_some_and(|b| b.is_null(index)) { + format!("{}:NULL", field.name()) + } else if let DataType::Struct(_) = field.data_type() { + let sv = ScalarValue::Struct(Arc::new( + column.as_struct().to_owned(), + )); + format!("{}:{sv}", field.name()) + } else { + let sv = array_value_to_string(column, 0).unwrap(); + format!("{}:{sv}", field.name()) + } + }) .collect::>() .join(",") - )?, - None => write!(f, "NULL")?, - }, + )? + } ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; @@ -3152,6 +3215,28 @@ impl fmt::Debug for ScalarValue { ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), ScalarValue::List(_) => write!(f, "List({self})"), ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), + ScalarValue::Struct(struct_arr) => { + // ScalarValue Struct should always have a single element + assert_eq!(struct_arr.len(), 1); + + let columns = struct_arr.columns(); + let fields = struct_arr.fields(); + + write!( + f, + "Struct({{{}}})", + columns + .iter() + .zip(fields.iter()) + .map(|(column, field)| { + let sv = array_value_to_string(column, 0).unwrap(); + let name = field.name(); + format!("{name}:{sv}") + }) + .collect::>() + .join(",") + ) + } ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3183,21 +3268,6 @@ impl fmt::Debug for ScalarValue { ScalarValue::DurationNanosecond(_) => { write!(f, "DurationNanosecond(\"{self}\")") } - ScalarValue::Struct(e, fields) => { - // Use Debug representation of field values - match e { - Some(l) => write!( - f, - "Struct({{{}}})", - l.iter() - .zip(fields.iter()) - .map(|(value, field)| format!("{}:{:?}", field.name(), value)) - .collect::>() - .join(",") - ), - None => write!(f, "Struct(NULL)"), - } - } ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), } @@ -3246,18 +3316,100 @@ mod tests { use std::sync::Arc; use super::*; - use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; + use crate::cast::{ + as_string_array, as_struct_array, as_uint32_array, as_uint64_array, + }; use arrow::buffer::OffsetBuffer; - use arrow::compute::{concat, is_null, kernels}; + use arrow::compute::{is_null, kernels}; use arrow::datatypes::{ArrowNumericType, ArrowPrimitiveType}; use arrow::util::pretty::pretty_format_columns; - + use arrow_buffer::Buffer; use chrono::NaiveDate; use rand::Rng; #[test] - fn test_to_array_of_size_for_list() { + fn test_scalar_value_from_for_struct() { + let boolean = Arc::new(BooleanArray::from(vec![false])); + let int = Arc::new(Int32Array::from(vec![42])); + + let expected = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + + let arrays = vec![boolean as ArrayRef, int as ArrayRef]; + let fields = Fields::from(vec![ + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]); + let sv = ScalarValue::from((fields, arrays)); + let struct_arr = sv.to_array().unwrap(); + let actual = as_struct_array(&struct_arr).unwrap(); + assert_eq!(actual, &expected); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn test_scalar_value_from_for_struct_should_panic() { + let fields = Fields::from(vec![ + Field::new("bool", DataType::Boolean, false), + Field::new("i32", DataType::Int32, false), + ]); + + let arrays = vec![ + Arc::new(BooleanArray::from(vec![false, true, false, false])) as ArrayRef, + Arc::new(Int32Array::from(vec![42, 28, 19, 31])), + ]; + + let _ = ScalarValue::from((fields, arrays)); + } + + #[test] + fn test_to_array_of_size_for_nested() { + // Struct + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + let sv = ScalarValue::Struct(Arc::new(struct_array)); + let actual_arr = sv.to_array_of_size(2).unwrap(); + + let boolean = Arc::new(BooleanArray::from(vec![ + false, false, true, true, false, false, true, true, + ])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31, 42, 28, 19, 31])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + + let actual = as_struct_array(&actual_arr).unwrap(); + assert_eq!(actual, &struct_array); + + // List let arr = ListArray::from_iter_primitive::(vec![Some(vec![ Some(1), None, @@ -3268,7 +3420,7 @@ mod tests { let actual_arr = sv .to_array_of_size(2) .expect("Failed to convert to array of size"); - let actual_list_arr = as_list_array(&actual_arr); + let actual_list_arr = actual_arr.as_list::(); let arr = ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), None, Some(2)]), @@ -3358,6 +3510,94 @@ mod tests { .collect() } + #[test] + fn test_iter_to_array_struct() { + let s1 = StructArray::from(vec![ + ( + Arc::new(Field::new("A", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![false])) as ArrayRef, + ), + ( + Arc::new(Field::new("B", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![42])) as ArrayRef, + ), + ]); + + let s2 = StructArray::from(vec![ + ( + Arc::new(Field::new("A", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![false])) as ArrayRef, + ), + ( + Arc::new(Field::new("B", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![42])) as ArrayRef, + ), + ]); + + let scalars = vec![ + ScalarValue::Struct(Arc::new(s1)), + ScalarValue::Struct(Arc::new(s2)), + ]; + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + + let expected = StructArray::from(vec![ + ( + Arc::new(Field::new("A", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef, + ), + ( + Arc::new(Field::new("B", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![42, 42])) as ArrayRef, + ), + ]); + assert_eq!(array.as_ref(), &expected); + } + + #[test] + fn test_iter_to_array_struct_with_nulls() { + // non-null + let s1 = StructArray::from(( + vec![ + ( + Arc::new(Field::new("A", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![1])) as ArrayRef, + ), + ( + Arc::new(Field::new("B", DataType::Int64, false)), + Arc::new(Int64Array::from(vec![2])) as ArrayRef, + ), + ], + // Present the null mask, 1 is non-null, 0 is null + Buffer::from(&[1]), + )); + + // null + let s2 = StructArray::from(( + vec![ + ( + Arc::new(Field::new("A", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![3])) as ArrayRef, + ), + ( + Arc::new(Field::new("B", DataType::Int64, false)), + Arc::new(Int64Array::from(vec![4])) as ArrayRef, + ), + ], + Buffer::from(&[0]), + )); + + let scalars = vec![ + ScalarValue::Struct(Arc::new(s1)), + ScalarValue::Struct(Arc::new(s2)), + ]; + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let struct_array = array.as_struct(); + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_null(1)); + } + #[test] fn iter_to_array_primitive_test() { // List[[1,2,3]], List[null], List[[4,5]] @@ -3406,7 +3646,7 @@ mod tests { ]; let array = ScalarValue::iter_to_array(scalars).unwrap(); - let result = as_list_array(&array); + let result = array.as_list::(); // build expected array let string_builder = StringBuilder::with_capacity(5, 25); @@ -4473,40 +4713,34 @@ mod tests { false, )); - let scalar = ScalarValue::Struct( - Some(vec![ - ScalarValue::Int32(Some(23)), - ScalarValue::Boolean(Some(false)), - ScalarValue::from("Hello"), - ScalarValue::from(vec![ - ("e", ScalarValue::from(2i16)), - ("f", ScalarValue::from(3i64)), - ]), - ]), - vec![ + let struct_array = StructArray::from(vec![ + ( + field_e.clone(), + Arc::new(Int16Array::from(vec![2])) as ArrayRef, + ), + ( + field_f.clone(), + Arc::new(Int64Array::from(vec![3])) as ArrayRef, + ), + ]); + + let struct_array = StructArray::from(vec![ + ( field_a.clone(), + Arc::new(Int32Array::from(vec![23])) as ArrayRef, + ), + ( field_b.clone(), + Arc::new(BooleanArray::from(vec![false])) as ArrayRef, + ), + ( field_c.clone(), - field_d.clone(), - ] - .into(), - ); - - // Check Display - assert_eq!( - format!("{scalar}"), - String::from("{A:23,B:false,C:Hello,D:{e:2,f:3}}") - ); - - // Check Debug - assert_eq!( - format!("{scalar:?}"), - String::from( - r#"Struct({A:Int32(23),B:Boolean(false),C:Utf8("Hello"),D:Struct({e:Int16(2),f:Int64(3)})})"# - ) - ); + Arc::new(StringArray::from(vec!["Hello"])) as ArrayRef, + ), + (field_d.clone(), Arc::new(struct_array) as ArrayRef), + ]); + let scalar = ScalarValue::Struct(Arc::new(struct_array)); - // Convert to length-2 array let array = scalar .to_array_of_size(2) .expect("Failed to convert to array of size"); @@ -4548,7 +4782,10 @@ mod tests { // None version let none_scalar = ScalarValue::try_from(array.data_type()).unwrap(); assert!(none_scalar.is_null()); - assert_eq!(format!("{none_scalar:?}"), String::from("Struct(NULL)")); + assert_eq!( + format!("{none_scalar:?}"), + String::from("Struct({A:,B:,C:,D:})") + ); // Construct with convenience From> let constructed = ScalarValue::from(vec![ @@ -4608,26 +4845,26 @@ mod tests { let expected = Arc::new(StructArray::from(vec![ ( - field_a, + field_a.clone(), Arc::new(Int32Array::from(vec![23, 7, -1000])) as ArrayRef, ), ( - field_b, + field_b.clone(), Arc::new(BooleanArray::from(vec![false, true, true])) as ArrayRef, ), ( - field_c, + field_c.clone(), Arc::new(StringArray::from(vec!["Hello", "World", "!!!!!"])) as ArrayRef, ), ( - field_d, + field_d.clone(), Arc::new(StructArray::from(vec![ ( - field_e, + field_e.clone(), Arc::new(Int16Array::from(vec![2, 4, 6])) as ArrayRef, ), ( - field_f, + field_f.clone(), Arc::new(Int64Array::from(vec![3, 5, 7])) as ArrayRef, ), ])) as ArrayRef, @@ -4687,6 +4924,7 @@ mod tests { // iter_to_array for struct scalars let array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap(); + let array = as_struct_array(&array).unwrap(); let expected = StructArray::from(vec![ ( @@ -4718,7 +4956,7 @@ mod tests { // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); - let array = as_list_array(&array); + let array = array.as_list::(); // Construct expected array with array builders let field_a_builder = StringBuilder::with_capacity(4, 1024); @@ -4732,6 +4970,7 @@ mod tests { Box::new(field_primitive_list_builder), ], ); + let mut list_builder = ListBuilder::new(element_builder); list_builder @@ -4865,7 +5104,7 @@ mod tests { ScalarValue::List(Arc::new(arr3)), ]) .unwrap(); - let array = as_list_array(&array); + let array = array.as_list::(); // Construct expected array with array builders let inner_builder = Int32Array::builder(6); @@ -5334,43 +5573,71 @@ mod tests { Field::new("a", DataType::UInt64, true), Field::new("b", DataType::Struct(fields_b.clone()), true), ]); - let scalars = vec![ - ScalarValue::Struct(None, fields.clone()), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(None), - ScalarValue::Struct(None, fields_b.clone()), - ]), - fields.clone(), + + let struct_value = vec![ + ( + fields[0].clone(), + Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef, ), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(None), - ScalarValue::Struct( - Some(vec![ScalarValue::UInt64(None), ScalarValue::UInt64(None)]), - fields_b.clone(), + ( + fields[1].clone(), + Arc::new(StructArray::from(vec![ + ( + fields_b[0].clone(), + Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef, ), - ]), - fields.clone(), - ), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(Some(1)), - ScalarValue::Struct( - Some(vec![ - ScalarValue::UInt64(Some(2)), - ScalarValue::UInt64(Some(3)), - ]), - fields_b, + ( + fields_b[1].clone(), + Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef, ), - ]), - fields, + ])) as ArrayRef, ), ]; + let struct_value_with_nulls = vec![ + ( + fields[0].clone(), + Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef, + ), + ( + fields[1].clone(), + Arc::new(StructArray::from(( + vec![ + ( + fields_b[0].clone(), + Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef, + ), + ( + fields_b[1].clone(), + Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef, + ), + ], + Buffer::from(&[0]), + ))) as ArrayRef, + ), + ]; + + let scalars = vec![ + // all null + ScalarValue::Struct(Arc::new(StructArray::from(( + struct_value.clone(), + Buffer::from(&[0]), + )))), + // field 1 valid, field 2 null + ScalarValue::Struct(Arc::new(StructArray::from(( + struct_value_with_nulls.clone(), + Buffer::from(&[1]), + )))), + // all valid + ScalarValue::Struct(Arc::new(StructArray::from(( + struct_value.clone(), + Buffer::from(&[1]), + )))), + ]; + let check_array = |array| { let is_null = is_null(&array).unwrap(); - assert_eq!(is_null, BooleanArray::from(vec![true, false, false, false])); + assert_eq!(is_null, BooleanArray::from(vec![true, false, false])); let formatted = pretty_format_columns("col", &[array]).unwrap().to_string(); let formatted = formatted.split('\n').collect::>(); @@ -5379,8 +5646,7 @@ mod tests { "| col |", "+---------------------------+", "| |", - "| {a: , b: } |", - "| {a: , b: {ba: , bb: }} |", + "| {a: 1, b: } |", "| {a: 1, b: {ba: 2, bb: 3}} |", "+---------------------------+", ]; @@ -5401,7 +5667,7 @@ mod tests { .collect::>>() .expect("Failed to convert to array"); let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); - let array = concat(&arrays).unwrap(); + let array = arrow::compute::concat(&arrays).unwrap(); check_array(array); } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 591a3e3131c8..71be8ec7e879 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1985,7 +1985,6 @@ mod tests { use crate::physical_plan::{DisplayAs, SendableRecordBatchStream}; use crate::physical_planner::PhysicalPlanner; use crate::prelude::{SessionConfig, SessionContext}; - use crate::scalar::ScalarValue; use crate::test_util::{scan_empty, scan_empty_with_partitions}; use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type, SchemaRef}; @@ -2310,10 +2309,11 @@ mod tests { /// Return a `null` literal representing a struct type like: `{ a: bool }` fn struct_literal() -> Expr { - let struct_literal = ScalarValue::Struct( - None, + let struct_literal = ScalarValue::try_from(DataType::Struct( vec![Field::new("foo", DataType::Boolean, false)].into(), - ); + )) + .unwrap(); + lit(struct_literal) } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5dbac0322fc0..0b29ad10d670 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -19,7 +19,7 @@ //! user defined aggregate functions use arrow::{array::AsArray, datatypes::Fields}; -use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray}; +use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; use arrow_schema::Schema; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -582,36 +582,29 @@ impl FirstSelector { // Internally, keep the data types as this type fn state_datatypes() -> Vec { - vec![ - DataType::Float64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - ] + vec![Self::output_datatype()] } /// Convert to a set of ScalarValues - fn to_state(&self) -> Vec { - vec![ - ScalarValue::Float64(Some(self.value)), - ScalarValue::TimestampNanosecond(Some(self.time), None), - ] - } + fn to_state(&self) -> Result { + let f64arr = Arc::new(Float64Array::from(vec![self.value])) as ArrayRef; + let timearr = + Arc::new(TimestampNanosecondArray::from(vec![self.time])) as ArrayRef; - /// return this selector as a single scalar (struct) value - fn to_scalar(&self) -> ScalarValue { - ScalarValue::Struct(Some(self.to_state()), Self::fields()) + let struct_arr = + StructArray::try_new(Self::fields(), vec![f64arr, timearr], None)?; + Ok(ScalarValue::Struct(Arc::new(struct_arr))) } } impl Accumulator for FirstSelector { fn state(&mut self) -> Result> { - let state = self.to_state().into_iter().collect::>(); - - Ok(state) + self.evaluate().map(|s| vec![s]) } /// produce the output structure fn evaluate(&mut self) -> Result { - Ok(self.to_scalar()) + self.to_state() } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index a58856e398e3..2e9df477d516 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -17,17 +17,18 @@ //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` -use arrow::datatypes::{DataType, Field}; use std::any::Any; +use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; use arrow::array::ArrayRef; -use std::collections::HashSet; +use arrow::datatypes::{DataType, Field}; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; @@ -137,10 +138,11 @@ impl Accumulator for DistinctArrayAggAccumulator { assert_eq!(values.len(), 1, "batch input should only include 1 column!"); let array = &values[0]; - let scalars = ScalarValue::convert_array_to_scalar_vec(array)?; - for scalar in scalars { - self.values.extend(scalar) + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(array)?; + for scalars in scalar_vec { + self.values.extend(scalars); } + Ok(()) } @@ -149,18 +151,7 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - assert_eq!( - states.len(), - 1, - "array_agg_distinct states must contain single array" - ); - - let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; - for scalars in scalar_vec { - self.values.extend(scalars) - } - - Ok(()) + self.update_batch(states) } fn evaluate(&mut self) -> Result { @@ -187,7 +178,8 @@ mod tests { use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::types::Int32Type; - use arrow_array::{Array, ListArray}; + use arrow_array::Array; + use arrow_array::ListArray; use arrow_buffer::OffsetBuffer; use datafusion_common::utils::array_into_list_array; use datafusion_common::{internal_err, DataFusionError}; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 5263fa83a6eb..587f40081c90 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -33,7 +33,10 @@ use crate::{ use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use arrow_array::cast::AsArray; +use arrow_array::{new_empty_array, StructArray}; use arrow_schema::{Fields, SortOptions}; + +use datafusion_common::utils::array_into_list_array; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; @@ -219,6 +222,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { if states.is_empty() { return Ok(()); } + // First entry in the state is the aggregation result. Second entry // stores values received for ordering requirement columns for each // aggregation value inside `ARRAY_AGG` list. For each `StructArray` @@ -241,29 +245,35 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { partition_values.push(self.values.clone().into()); partition_ordering_values.push(self.ordering_values.clone().into()); + // Convert array to Scalars to sort them easily. Convert back to array at evaluation. let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - for v in array_agg_res.into_iter() { partition_values.push(v.into()); } let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - let ordering_values = orderings.into_iter().map(|partition_ordering_rows| { + for partition_ordering_rows in orderings.into_iter() { // Extract value from struct to ordering_rows for each group/partition - partition_ordering_rows.into_iter().map(|ordering_row| { - if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row { - Ok(ordering_columns_per_row) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", - ordering_row.data_type() - ) - } - }).collect::>>() - }).collect::>>()?; - for ordering_values in ordering_values.into_iter() { - partition_ordering_values.push(ordering_values); + let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| { + if let ScalarValue::Struct(s) = ordering_row { + let mut ordering_columns_per_row = vec![]; + + for column in s.columns() { + let sv = ScalarValue::try_from_array(column, 0)?; + ordering_columns_per_row.push(sv); + } + + Ok(ordering_columns_per_row) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Arc) but got:{:?}", + ordering_row.data_type() + ) + } + }).collect::>>()?; + + partition_ordering_values.push(ordering_value); } let sort_options = self @@ -271,11 +281,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { .iter() .map(|sort_expr| sort_expr.options) .collect::>(); + (self.values, self.ordering_values) = merge_ordered_arrays( &mut partition_values, &mut partition_ordering_values, &sort_options, )?; + Ok(()) } @@ -323,20 +335,32 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { impl OrderSensitiveArrayAggAccumulator { fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); - let struct_field = Fields::from(fields); - - let orderings: Vec = self - .ordering_values - .iter() - .map(|ordering| { - ScalarValue::Struct(Some(ordering.clone()), struct_field.clone()) - }) - .collect(); - let struct_type = DataType::Struct(struct_field); + let num_columns = fields.len(); + let struct_field = Fields::from(fields.clone()); + + let mut column_wise_ordering_values = vec![]; + for i in 0..num_columns { + let column_values = self + .ordering_values + .iter() + .map(|x| x[i].clone()) + .collect::>(); + let array = if column_values.is_empty() { + new_empty_array(fields[i].data_type()) + } else { + ScalarValue::iter_to_array(column_values.into_iter())? + }; + column_wise_ordering_values.push(array); + } - // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases - let arr = ScalarValue::new_list(&orderings, &struct_type); - Ok(ScalarValue::List(arr)) + let ordering_array = StructArray::try_new( + struct_field.clone(), + column_wise_ordering_values, + None, + )?; + Ok(ScalarValue::List(Arc::new(array_into_list_array( + Arc::new(ordering_array), + )))) } } diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/physical-expr/src/aggregate/nth_value.rs index 368bd353a4ed..5d721e3a5e87 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/physical-expr/src/aggregate/nth_value.rs @@ -30,9 +30,9 @@ use crate::{ }; use arrow_array::cast::AsArray; -use arrow_array::ArrayRef; +use arrow_array::{new_empty_array, ArrayRef, StructArray}; use arrow_schema::{DataType, Field, Fields}; -use datafusion_common::utils::get_row_at_idx; +use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; @@ -271,7 +271,14 @@ impl Accumulator for NthValueAccumulator { let ordering_values = orderings.into_iter().map(|partition_ordering_rows| { // Extract value from struct to ordering_rows for each group/partition partition_ordering_rows.into_iter().map(|ordering_row| { - if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row { + if let ScalarValue::Struct(s) = ordering_row { + let mut ordering_columns_per_row = vec![]; + + for column in s.columns() { + let sv = ScalarValue::try_from_array(column, 0)?; + ordering_columns_per_row.push(sv); + } + Ok(ordering_columns_per_row) } else { exec_err!( @@ -306,7 +313,7 @@ impl Accumulator for NthValueAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.evaluate_values()]; if !self.ordering_req.is_empty() { - result.push(self.evaluate_orderings()); + result.push(self.evaluate_orderings()?); } Ok(result) } @@ -355,21 +362,35 @@ impl Accumulator for NthValueAccumulator { } impl NthValueAccumulator { - fn evaluate_orderings(&self) -> ScalarValue { + fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); - let struct_field = Fields::from(fields); + let struct_field = Fields::from(fields.clone()); - let orderings = self - .ordering_values - .iter() - .map(|ordering| { - ScalarValue::Struct(Some(ordering.clone()), struct_field.clone()) - }) - .collect::>(); - let struct_type = DataType::Struct(struct_field); + let mut column_wise_ordering_values = vec![]; + let num_columns = fields.len(); + for i in 0..num_columns { + let column_values = self + .ordering_values + .iter() + .map(|x| x[i].clone()) + .collect::>(); + let array = if column_values.is_empty() { + new_empty_array(fields[i].data_type()) + } else { + ScalarValue::iter_to_array(column_values.into_iter())? + }; + column_wise_ordering_values.push(array); + } + + let ordering_array = StructArray::try_new( + struct_field.clone(), + column_wise_ordering_values, + None, + )?; - // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases - ScalarValue::List(ScalarValue::new_list(&orderings, &struct_type)) + Ok(ScalarValue::List(Arc::new(array_into_list_array( + Arc::new(ordering_array), + )))) } fn evaluate_values(&self) -> ScalarValue { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 218399694884..667e53842e56 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -944,7 +944,8 @@ message Union{ repeated int32 type_ids = 3; } -message ScalarListValue { +// Used for List/FixedSizeList/LargeList/Struct +message ScalarNestedValue { bytes ipc_message = 1; bytes arrow_data = 2; Schema schema = 3; @@ -985,14 +986,6 @@ message IntervalMonthDayNanoValue { int64 nanos = 3; } -message StructValue { - // Note that a null struct value must have one or more fields, so we - // encode a null StructValue as one witth an empty field_values - // list. - repeated ScalarValue field_values = 2; - repeated Field fields = 3; -} - message ScalarFixedSizeBinary{ bytes values = 1; int32 length = 2; @@ -1023,9 +1016,10 @@ message ScalarValue{ // Literal Date32 value always has a unit of day int32 date_32_value = 14; ScalarTime32Value time32_value = 15; - ScalarListValue large_list_value = 16; - ScalarListValue list_value = 17; - ScalarListValue fixed_size_list_value = 18; + ScalarNestedValue large_list_value = 16; + ScalarNestedValue list_value = 17; + ScalarNestedValue fixed_size_list_value = 18; + ScalarNestedValue struct_value = 32; Decimal128 decimal128_value = 20; Decimal256 decimal256_value = 39; @@ -1045,7 +1039,6 @@ message ScalarValue{ bytes large_binary_value = 29; ScalarTime64Value time64_value = 30; IntervalMonthDayNanoValue interval_month_day_nano = 31; - StructValue struct_value = 32; ScalarFixedSizeBinary fixed_size_binary_value = 34; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 450b18dc0982..5b7d27d0dff0 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22870,7 +22870,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunctionNode { deserializer.deserialize_struct("datafusion.ScalarFunctionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarListValue { +impl serde::Serialize for ScalarNestedValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22887,7 +22887,7 @@ impl serde::Serialize for ScalarListValue { if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarListValue", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarNestedValue", len)?; if !self.ipc_message.is_empty() { #[allow(clippy::needless_borrow)] struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; @@ -22902,7 +22902,7 @@ impl serde::Serialize for ScalarListValue { struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarListValue { +impl<'de> serde::Deserialize<'de> for ScalarNestedValue { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -22954,13 +22954,13 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarListValue; + type Value = ScalarNestedValue; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarListValue") + formatter.write_str("struct datafusion.ScalarNestedValue") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -22993,14 +22993,14 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { } } } - Ok(ScalarListValue { + Ok(ScalarNestedValue { ipc_message: ipc_message__.unwrap_or_default(), arrow_data: arrow_data__.unwrap_or_default(), schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.ScalarListValue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ScalarNestedValue", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ScalarTime32Value { @@ -23561,6 +23561,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::FixedSizeListValue(v) => { struct_ser.serialize_field("fixedSizeListValue", v)?; } + scalar_value::Value::StructValue(v) => { + struct_ser.serialize_field("structValue", v)?; + } scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } @@ -23614,9 +23617,6 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::IntervalMonthDayNano(v) => { struct_ser.serialize_field("intervalMonthDayNano", v)?; } - scalar_value::Value::StructValue(v) => { - struct_ser.serialize_field("structValue", v)?; - } scalar_value::Value::FixedSizeBinaryValue(v) => { struct_ser.serialize_field("fixedSizeBinaryValue", v)?; } @@ -23670,6 +23670,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "listValue", "fixed_size_list_value", "fixedSizeListValue", + "struct_value", + "structValue", "decimal128_value", "decimal128Value", "decimal256_value", @@ -23700,8 +23702,6 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "time64Value", "interval_month_day_nano", "intervalMonthDayNano", - "struct_value", - "structValue", "fixed_size_binary_value", "fixedSizeBinaryValue", ]; @@ -23727,6 +23727,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { LargeListValue, ListValue, FixedSizeListValue, + StructValue, Decimal128Value, Decimal256Value, Date64Value, @@ -23742,7 +23743,6 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { LargeBinaryValue, Time64Value, IntervalMonthDayNano, - StructValue, FixedSizeBinaryValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -23784,6 +23784,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "largeListValue" | "large_list_value" => Ok(GeneratedField::LargeListValue), "listValue" | "list_value" => Ok(GeneratedField::ListValue), "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), + "structValue" | "struct_value" => Ok(GeneratedField::StructValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), @@ -23799,7 +23800,6 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "largeBinaryValue" | "large_binary_value" => Ok(GeneratedField::LargeBinaryValue), "time64Value" | "time64_value" => Ok(GeneratedField::Time64Value), "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), - "structValue" | "struct_value" => Ok(GeneratedField::StructValue), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -23940,6 +23940,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("fixedSizeListValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeListValue) +; + } + GeneratedField::StructValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("structValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue) ; } GeneratedField::Decimal128Value => { @@ -24036,13 +24043,6 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("intervalMonthDayNano")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::IntervalMonthDayNano) -; - } - GeneratedField::StructValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("structValue")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue) ; } GeneratedField::FixedSizeBinaryValue => { @@ -25535,115 +25535,6 @@ impl<'de> serde::Deserialize<'de> for Struct { deserializer.deserialize_struct("datafusion.Struct", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for StructValue { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.field_values.is_empty() { - len += 1; - } - if !self.fields.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.StructValue", len)?; - if !self.field_values.is_empty() { - struct_ser.serialize_field("fieldValues", &self.field_values)?; - } - if !self.fields.is_empty() { - struct_ser.serialize_field("fields", &self.fields)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for StructValue { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "field_values", - "fieldValues", - "fields", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - FieldValues, - Fields, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "fieldValues" | "field_values" => Ok(GeneratedField::FieldValues), - "fields" => Ok(GeneratedField::Fields), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = StructValue; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.StructValue") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut field_values__ = None; - let mut fields__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::FieldValues => { - if field_values__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldValues")); - } - field_values__ = Some(map_.next_value()?); - } - GeneratedField::Fields => { - if fields__.is_some() { - return Err(serde::de::Error::duplicate_field("fields")); - } - fields__ = Some(map_.next_value()?); - } - } - } - Ok(StructValue { - field_values: field_values__.unwrap_or_default(), - fields: fields__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.StructValue", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for SubqueryAliasNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7894285129f6..cdf4dadcf894 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1149,9 +1149,10 @@ pub struct Union { #[prost(int32, repeated, tag = "3")] pub type_ids: ::prost::alloc::vec::Vec, } +/// Used for List/FixedSizeList/LargeList/Struct #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ScalarListValue { +pub struct ScalarNestedValue { #[prost(bytes = "vec", tag = "1")] pub ipc_message: ::prost::alloc::vec::Vec, #[prost(bytes = "vec", tag = "2")] @@ -1236,17 +1237,6 @@ pub struct IntervalMonthDayNanoValue { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct StructValue { - /// Note that a null struct value must have one or more fields, so we - /// encode a null StructValue as one witth an empty field_values - /// list. - #[prost(message, repeated, tag = "2")] - pub field_values: ::prost::alloc::vec::Vec, - #[prost(message, repeated, tag = "3")] - pub fields: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] pub values: ::prost::alloc::vec::Vec, @@ -1258,7 +1248,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34" )] pub value: ::core::option::Option, } @@ -1303,11 +1293,13 @@ pub mod scalar_value { #[prost(message, tag = "15")] Time32Value(super::ScalarTime32Value), #[prost(message, tag = "16")] - LargeListValue(super::ScalarListValue), + LargeListValue(super::ScalarNestedValue), #[prost(message, tag = "17")] - ListValue(super::ScalarListValue), + ListValue(super::ScalarNestedValue), #[prost(message, tag = "18")] - FixedSizeListValue(super::ScalarListValue), + FixedSizeListValue(super::ScalarNestedValue), + #[prost(message, tag = "32")] + StructValue(super::ScalarNestedValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] @@ -1338,8 +1330,6 @@ pub mod scalar_value { Time64Value(super::ScalarTime64Value), #[prost(message, tag = "31")] IntervalMonthDayNano(super::IntervalMonthDayNanoValue), - #[prost(message, tag = "32")] - StructValue(super::StructValue), #[prost(message, tag = "34")] FixedSizeBinaryValue(super::ScalarFixedSizeBinary), } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 8ef7271ff2a5..0689da803538 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -686,14 +686,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), // ScalarValue::List is serialized using arrow IPC format - Value::ListValue(scalar_list) - | Value::FixedSizeListValue(scalar_list) - | Value::LargeListValue(scalar_list) => { - let protobuf::ScalarListValue { + Value::ListValue(v) + | Value::FixedSizeListValue(v) + | Value::LargeListValue(v) + | Value::StructValue(v) => { + let protobuf::ScalarNestedValue { ipc_message, arrow_data, schema, - } = &scalar_list; + } = &v; let schema: Schema = if let Some(schema_ref) = schema { schema_ref.try_into()? @@ -739,6 +740,9 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::FixedSizeListValue(_) => { Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()) } + Value::StructValue(_) => { + Self::Struct(arr.as_struct().to_owned().into()) + } _ => unreachable!(), } } @@ -839,28 +843,6 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some( IntervalMonthDayNanoType::make_value(v.months, v.days, v.nanos), )), - Value::StructValue(v) => { - // all structs must have at least 1 field, so we treat - // an empty values list as NULL - let values = if v.field_values.is_empty() { - None - } else { - Some( - v.field_values - .iter() - .map(|v| v.try_into()) - .collect::, _>>()?, - ) - }; - - let fields = v - .fields - .iter() - .map(Field::try_from) - .collect::>()?; - - Self::Struct(values, fields) - } Value::FixedSizeBinaryValue(v) => { Self::FixedSizeBinary(v.length, Some(v.clone().values)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e5948de40a23..4df7f9fb6bf3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1178,17 +1178,17 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::LargeUtf8Value(s.to_owned()) }) } - // ScalarValue::List and ScalarValue::FixedSizeList are serialized using - // Arrow IPC messages as a single column RecordBatch ScalarValue::List(arr) => { - encode_scalar_list_value(arr.to_owned() as ArrayRef, val) + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) } ScalarValue::LargeList(arr) => { - // Wrap in a "field_name" column - encode_scalar_list_value(arr.to_owned() as ArrayRef, val) + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) } ScalarValue::FixedSizeList(arr) => { - encode_scalar_list_value(arr.to_owned() as ArrayRef, val) + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::Struct(arr) => { + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) @@ -1400,34 +1400,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }; Ok(protobuf::ScalarValue { value: Some(value) }) } - - ScalarValue::Struct(values, fields) => { - // encode null as empty field values list - let field_values = if let Some(values) = values { - if values.is_empty() { - return Err(Error::InvalidScalarValue(val.clone())); - } - values - .iter() - .map(|v| v.try_into()) - .collect::, _>>()? - } else { - vec![] - }; - - let fields = fields - .iter() - .map(|f| f.as_ref().try_into()) - .collect::, _>>()?; - - Ok(protobuf::ScalarValue { - value: Some(Value::StructValue(protobuf::StructValue { - field_values, - fields, - })), - }) - } - ScalarValue::Dictionary(index_type, val) => { let value: protobuf::ScalarValue = val.as_ref().try_into()?; Ok(protobuf::ScalarValue { @@ -1709,7 +1681,9 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } -fn encode_scalar_list_value( +// ScalarValue::List / FixedSizeList / LargeList / Struct are serialized using +// Arrow IPC messages as a single column RecordBatch +fn encode_scalar_nested_value( arr: ArrayRef, val: &ScalarValue, ) -> Result { @@ -1729,7 +1703,7 @@ fn encode_scalar_list_value( let schema: protobuf::Schema = batch.schema().try_into()?; - let scalar_list_value = protobuf::ScalarListValue { + let scalar_list_value = protobuf::ScalarNestedValue { ipc_message: encoded_message.ipc_message, arrow_data: encoded_message.arrow_data, schema: Some(schema), @@ -1749,6 +1723,11 @@ fn encode_scalar_list_value( scalar_list_value, )), }), + ScalarValue::Struct(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::StructValue( + scalar_list_value, + )), + }), _ => unreachable!(), } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 1bcdffe89236..652e59672bc7 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -21,6 +21,7 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::array::{BooleanArray, Int32Array}; use arrow::csv::WriterBuilder; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, @@ -61,8 +62,8 @@ use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; +use datafusion_proto::logical_plan::from_proto; use datafusion_proto::logical_plan::LogicalExtensionCodec; -use datafusion_proto::logical_plan::{from_proto, to_proto}; use datafusion_proto::protobuf; #[cfg(feature = "json")] @@ -746,32 +747,6 @@ impl LogicalExtensionCodec for TopKExtensionCodec { } } -#[test] -fn scalar_values_error_serialization() { - let should_fail_on_seralize: Vec = vec![ - // Should fail due to empty values - ScalarValue::Struct( - Some(vec![]), - vec![Field::new("item", DataType::Int16, true)].into(), - ), - ]; - - for test_case in should_fail_on_seralize.into_iter() { - let proto: Result = - (&test_case).try_into(); - - // Validation is also done on read, so if serialization passed - // also try to convert back to ScalarValue - if let Ok(proto) = proto { - let res: Result = (&proto).try_into(); - assert!( - res.is_err(), - "The value {test_case:?} unexpectedly serialized without error:{res:?}" - ); - } - } -} - #[test] fn round_trip_scalar_values() { let should_pass: Vec = vec![ @@ -955,23 +930,22 @@ fn round_trip_scalar_values() { ScalarValue::Binary(None), ScalarValue::LargeBinary(Some(b"bar".to_vec())), ScalarValue::LargeBinary(None), - ScalarValue::Struct( - Some(vec![ - ScalarValue::Int32(Some(23)), - ScalarValue::Boolean(Some(false)), - ]), - Fields::from(vec![ + ScalarValue::from(( + vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Boolean, false), - ]), - ), - ScalarValue::Struct( - None, - Fields::from(vec![ - Field::new("a", DataType::Int32, true), - Field::new("a", DataType::Boolean, false), - ]), - ), + ] + .into(), + vec![ + Arc::new(Int32Array::from(vec![Some(23)])) as ArrayRef, + Arc::new(BooleanArray::from(vec![Some(false)])) as ArrayRef, + ], + )), + ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Boolean, false), + ]))) + .unwrap(), ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())), ScalarValue::FixedSizeBinary(0, None), ScalarValue::FixedSizeBinary(5, None), diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 2a39e3138869..4002164cc918 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -380,3 +380,15 @@ Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64 physical_plan ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] --PlaceholderRowExec + +# Explain Struct + +query TT +explain select struct(1, 2.3, 'abc'); +---- +logical_plan +Projection: Struct({c0:1,c1:2.3,c2:abc}) AS struct(Int64(1),Float64(2.3),Utf8("abc")) +--EmptyRelation +physical_plan +ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))] +--PlaceholderRowExec