diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index c8adae34f6455..010221b0485f9 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -29,8 +29,8 @@ use arrow_buffer::IntervalMonthDayNano; use crate::cast::{ as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, - as_large_list_array, as_list_array, as_primitive_array, as_string_array, - as_struct_array, + as_large_list_array, as_list_array, as_map_array, as_primitive_array, + as_string_array, as_struct_array, }; use crate::error::{Result, _internal_err}; @@ -236,6 +236,40 @@ fn hash_struct_array( Ok(()) } +fn hash_map_array( + array: &MapArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let offsets = array.offsets(); + + // Create hashes for each entry in each row + let mut values_hashes = vec![0u64; array.entries().len()]; + create_hashes(array.entries().columns(), random_state, &mut values_hashes)?; + + // Combine the hashes for entries on each row with each other and previous hash for that row + if let Some(nulls) = nulls { + for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + if nulls.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + + Ok(()) +} + fn hash_list_array( array: &GenericListArray, random_state: &RandomState, @@ -400,6 +434,10 @@ pub fn create_hashes<'a>( let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } + DataType::Map(_, _) => { + let array = as_map_array(array)?; + hash_map_array(array, random_state, hashes_buffer)?; + } DataType::FixedSizeList(_,_) => { let array = as_fixed_size_list_array(array)?; hash_fixed_list_array(array, random_state, hashes_buffer)?; @@ -572,6 +610,7 @@ mod tests { Some(vec![Some(3), None, Some(5)]), None, Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![]), ]; let list_array = Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; @@ -581,6 +620,7 @@ mod tests { assert_eq!(hashes[0], hashes[5]); assert_eq!(hashes[1], hashes[4]); assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[1], hashes[6]); // null vs empty list } #[test] @@ -692,6 +732,64 @@ mod tests { assert_eq!(hashes[0], hashes[1]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_map_arrays() { + let mut builder = + MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); + // Row 0 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 1 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 2 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + builder.values().append_value(3); + builder.append(true).unwrap(); + // Row 3 + builder.keys().append_value("key1"); + builder.keys().append_value("key3"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 4 + builder.keys().append_value("key1"); + builder.values().append_value(1); + builder.append(true).unwrap(); + // Row 5 + builder.keys().append_value("key1"); + builder.values().append_null(); + builder.append(true).unwrap(); + // Row 6 + builder.append(true).unwrap(); + // Row 7 + builder.keys().append_value("key1"); + builder.values().append_value(1); + builder.append(false).unwrap(); + + let array = Arc::new(builder.finish()) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); // same value + assert_ne!(hashes[0], hashes[2]); // different value + assert_ne!(hashes[0], hashes[3]); // different key + assert_ne!(hashes[0], hashes[4]); // missing an entry + assert_ne!(hashes[4], hashes[5]); // filled vs null value + assert_eq!(hashes[6], hashes[7]); // empty vs null map + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 38f70e4c1466c..80164d04918ac 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1773,6 +1773,7 @@ impl ScalarValue { } DataType::List(_) | DataType::LargeList(_) + | DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) => { let arrays = scalars.map(|s| s.to_array()).collect::>>()?; @@ -1841,7 +1842,6 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) | DataType::ListView(_) | DataType::LargeListView(_) => { diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 26bfb4a5922e6..e530e14df66ea 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -302,3 +302,11 @@ SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), a {POST: 41, HEAD: 33, PATCH: 30} {POST: 41, HEAD: 33, PATCH: 30} {POST: 41, HEAD: 33, PATCH: 30} + + +query ? +VALUES (MAP(['a'], [1])), (MAP(['b'], [2])), (MAP(['c', 'a'], [3, 1])) +---- +{a: 1} +{b: 2} +{c: 3, a: 1} diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 5768c44bbf6c8..15c447114819e 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer}; use async_recursion::async_recursion; -use datafusion::arrow::array::GenericListArray; +use datafusion::arrow::array::{GenericListArray, MapArray}; use datafusion::arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; @@ -51,6 +51,7 @@ use crate::variation_const::{ INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, }; +use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ @@ -1449,21 +1450,14 @@ fn from_substrait_type( from_substrait_type(value_type, extensions, dfs_names, name_idx)?, true, )); - match map.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - Ok(DataType::Map( - Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), - )), - false, - )) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - )?, - } + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, // whether keys are sorted + )) } r#type::Kind::Decimal(d) => match d.type_variation_reference { DECIMAL_128_TYPE_VARIATION_REF => { @@ -1743,11 +1737,23 @@ fn from_substrait_literal( ) } Some(LiteralType::List(l)) => { + // Each element should start the name index from the same value, then we increase it + // once at the end + let mut element_name_idx = *name_idx; let elements = l .values .iter() - .map(|el| from_substrait_literal(el, extensions, dfs_names, name_idx)) + .map(|el| { + element_name_idx = *name_idx; + from_substrait_literal( + el, + extensions, + dfs_names, + &mut element_name_idx, + ) + }) .collect::>>()?; + *name_idx = element_name_idx; if elements.is_empty() { return substrait_err!( "Empty list must be encoded as EmptyList literal type, not List" @@ -1785,6 +1791,84 @@ fn from_substrait_literal( } } } + Some(LiteralType::Map(m)) => { + // Each entry should start the name index from the same value, then we increase it + // once at the end + let mut entry_name_idx = *name_idx; + let entries = m + .key_values + .iter() + .map(|kv| { + entry_name_idx = *name_idx; + let key_sv = from_substrait_literal( + kv.key.as_ref().unwrap(), + extensions, + dfs_names, + &mut entry_name_idx, + )?; + let value_sv = from_substrait_literal( + kv.value.as_ref().unwrap(), + extensions, + dfs_names, + &mut entry_name_idx, + )?; + ScalarStructBuilder::new() + .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) + .with_scalar( + Field::new("value", value_sv.data_type(), true), + value_sv, + ) + .build() + }) + .collect::>>()?; + *name_idx = entry_name_idx; + + if entries.is_empty() { + return substrait_err!( + "Empty map must be encoded as EmptyMap literal type, not Map" + ); + } + + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(Field::new("entries", entries[0].data_type(), false)), + OffsetBuffer::new(vec![0, entries.len() as i32].into()), + ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), + None, + false, + ))) + } + Some(LiteralType::EmptyMap(m)) => { + let key = match &m.key { + Some(k) => Ok(k), + _ => plan_err!("Missing key type for empty map"), + }?; + let value = match &m.value { + Some(v) => Ok(v), + _ => plan_err!("Missing value type for empty map"), + }?; + let key_type = from_substrait_type(key, extensions, dfs_names, name_idx)?; + let value_type = from_substrait_type(value, extensions, dfs_names, name_idx)?; + + // new_empty_array on a MapType creates a too empty array + // We want it to contain an empty struct array to align with an empty MapBuilder one + let entries = Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + ); + let struct_array = + new_empty_array(entries.data_type()).as_struct().to_owned(); + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(entries), + OffsetBuffer::new(vec![0, 0].into()), + struct_array, + None, + false, + ))) + } Some(LiteralType::Struct(s)) => { let mut builder = ScalarStructBuilder::new(); for (i, field) in s.fields.iter().enumerate() { @@ -2013,6 +2097,29 @@ fn from_substrait_null( ), } } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + + let key_type = + from_substrait_type(key_type, extensions, dfs_names, name_idx)?; + let value_type = + from_substrait_type(value_type, extensions, dfs_names, name_idx)?; + let entries_field = Arc::new(Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + )); + + DataType::Map(entries_field, false /* keys sorted */).try_into() + } r#type::Kind::Struct(s) => { let fields = from_substrait_struct_type(s, extensions, dfs_names, name_idx)?; diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 8f69cc5e218f6..8263209ffccc7 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -57,8 +57,10 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::literal::map::KeyValue; use substrait::proto::expression::literal::{ - user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Struct, UserDefined, + user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, Struct, + UserDefined, }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; @@ -1922,6 +1924,48 @@ fn to_substrait_literal( convert_array_to_literal_list(l, extensions)?, LARGE_CONTAINER_TYPE_VARIATION_REF, ), + ScalarValue::Map(m) => { + let map = if m.is_empty() || m.value(0).is_empty() { + let mt = to_substrait_type(m.data_type(), m.is_nullable(), extensions)?; + let mt = match mt { + substrait::proto::Type { + kind: Some(r#type::Kind::Map(mt)), + } => Ok(mt.as_ref().to_owned()), + _ => exec_err!("Unexpected type for a map: {mt:?}"), + }?; + LiteralType::EmptyMap(mt) + } else { + let keys = (0..m.keys().len()) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&m.keys(), i)?, + extensions, + ) + }) + .collect::>>()?; + let values = (0..m.values().len()) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&m.values(), i)?, + extensions, + ) + }) + .collect::>>()?; + + let key_values = keys + .into_iter() + .zip(values.into_iter()) + .map(|(k, v)| { + Ok(KeyValue { + key: Some(k), + value: Some(v), + }) + }) + .collect::>>()?; + LiteralType::Map(Map { key_values }) + }; + (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF) + } ScalarValue::Struct(s) => ( LiteralType::Struct(Struct { fields: s @@ -1967,7 +2011,7 @@ fn convert_array_to_literal_list( .collect::>>()?; if values.is_empty() { - let et = match to_substrait_type( + let lt = match to_substrait_type( array.data_type(), array.is_nullable(), extensions, @@ -1977,7 +2021,7 @@ fn convert_array_to_literal_list( } => lt.as_ref().to_owned(), _ => unreachable!(), }; - Ok(LiteralType::EmptyList(et)) + Ok(LiteralType::EmptyList(lt)) } else { Ok(LiteralType::List(List { values })) } @@ -2094,7 +2138,9 @@ mod test { from_substrait_literal_without_names, from_substrait_type_without_names, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion::arrow::array::GenericListArray; + use datafusion::arrow::array::{ + GenericListArray, Int64Builder, MapBuilder, StringBuilder, + }; use datafusion::arrow::datatypes::Field; use datafusion::common::scalar::ScalarStructBuilder; use std::collections::HashMap; @@ -2160,6 +2206,28 @@ mod test { ), )))?; + // Null map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(false)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Empty map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Valid map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.keys().append_value("key1"); + map_builder.keys().append_value("key2"); + map_builder.values().append_value(1); + map_builder.values().append_value(2); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + let c0 = Field::new("c0", DataType::Boolean, true); let c1 = Field::new("c1", DataType::Int32, true); let c2 = Field::new("c2", DataType::Utf8, true); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 5b4389c832c7c..439e3efa29228 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -749,7 +749,7 @@ async fn roundtrip_values() -> Result<()> { [[-213.1, NULL, 5.5, 2.0, 1.0], []], \ arrow_cast([1,2,3], 'LargeList(Int64)'), \ STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \ - [STRUCT(STRUCT('a' AS string_field) AS struct_field)]\ + [STRUCT(STRUCT('a' AS string_field) AS struct_field), STRUCT(STRUCT('b' AS string_field) AS struct_field)]\ ), \ (NULL, NULL, NULL, NULL, NULL, NULL)", "Values: \ @@ -759,7 +759,7 @@ async fn roundtrip_values() -> Result<()> { List([[-213.1, , 5.5, 2.0, 1.0], []]), \ LargeList([1, 2, 3]), \ Struct({c0:true,int_field:1,c2:}), \ - List([{struct_field: {string_field: a}}])\ + List([{struct_field: {string_field: a}}, {struct_field: {string_field: b}}])\ ), \ (Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())", true).await