-
Notifications
You must be signed in to change notification settings - Fork 1.9k
feat: support Map literals in Substrait consumer and producer #11547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0b7b0a5
f36c96b
e673c27
c7bad12
0957739
3df03ae
06b7d3c
62149fa
9b5b1b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,8 +29,8 @@ use arrow_buffer::IntervalMonthDayNano; | |
|
|
||
| use crate::cast::{ | ||
| as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, | ||
| as_large_list_array, as_list_array, as_primitive_array, as_string_array, | ||
| as_struct_array, | ||
| as_large_list_array, as_list_array, as_map_array, as_primitive_array, | ||
| as_string_array, as_struct_array, | ||
| }; | ||
| use crate::error::{Result, _internal_err}; | ||
|
|
||
|
|
@@ -236,6 +236,40 @@ fn hash_struct_array( | |
| Ok(()) | ||
| } | ||
|
|
||
| fn hash_map_array( | ||
| array: &MapArray, | ||
| random_state: &RandomState, | ||
| hashes_buffer: &mut [u64], | ||
| ) -> Result<()> { | ||
| let nulls = array.nulls(); | ||
| let offsets = array.offsets(); | ||
|
|
||
| // Create hashes for each entry in each row | ||
| let mut values_hashes = vec![0u64; array.entries().len()]; | ||
| create_hashes(array.entries().columns(), random_state, &mut values_hashes)?; | ||
|
|
||
| // Combine the hashes for entries on each row with each other and previous hash for that row | ||
| if let Some(nulls) = nulls { | ||
| for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { | ||
| if nulls.is_valid(i) { | ||
| let hash = &mut hashes_buffer[i]; | ||
| for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { | ||
| *hash = combine_hashes(*hash, *values_hash); | ||
| } | ||
| } | ||
| } | ||
| } else { | ||
| for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { | ||
| let hash = &mut hashes_buffer[i]; | ||
| for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { | ||
| *hash = combine_hashes(*hash, *values_hash); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn hash_list_array<OffsetSize>( | ||
| array: &GenericListArray<OffsetSize>, | ||
| random_state: &RandomState, | ||
|
|
@@ -400,6 +434,10 @@ pub fn create_hashes<'a>( | |
| let array = as_large_list_array(array)?; | ||
| hash_list_array(array, random_state, hashes_buffer)?; | ||
| } | ||
| DataType::Map(_, _) => { | ||
| let array = as_map_array(array)?; | ||
| hash_map_array(array, random_state, hashes_buffer)?; | ||
| } | ||
| DataType::FixedSizeList(_,_) => { | ||
| let array = as_fixed_size_list_array(array)?; | ||
| hash_fixed_list_array(array, random_state, hashes_buffer)?; | ||
|
|
@@ -572,6 +610,7 @@ mod tests { | |
| Some(vec![Some(3), None, Some(5)]), | ||
| None, | ||
| Some(vec![Some(0), Some(1), Some(2)]), | ||
| Some(vec![]), | ||
| ]; | ||
| let list_array = | ||
| Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(data)) as ArrayRef; | ||
|
|
@@ -581,6 +620,7 @@ mod tests { | |
| assert_eq!(hashes[0], hashes[5]); | ||
| assert_eq!(hashes[1], hashes[4]); | ||
| assert_eq!(hashes[2], hashes[3]); | ||
| assert_eq!(hashes[1], hashes[6]); // null vs empty list | ||
| } | ||
|
|
||
| #[test] | ||
|
|
@@ -692,6 +732,64 @@ mod tests { | |
| assert_eq!(hashes[0], hashes[1]); | ||
| } | ||
|
|
||
| #[test] | ||
| // Tests actual values of hashes, which are different if forcing collisions | ||
| #[cfg(not(feature = "force_hash_collisions"))] | ||
| fn create_hashes_for_map_arrays() { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would help me understand / verify this test if you could use a
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done! 06b7d3c |
||
| let mut builder = | ||
| MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); | ||
| // Row 0 | ||
| builder.keys().append_value("key1"); | ||
| builder.keys().append_value("key2"); | ||
| builder.values().append_value(1); | ||
| builder.values().append_value(2); | ||
| builder.append(true).unwrap(); | ||
| // Row 1 | ||
| builder.keys().append_value("key1"); | ||
| builder.keys().append_value("key2"); | ||
| builder.values().append_value(1); | ||
| builder.values().append_value(2); | ||
| builder.append(true).unwrap(); | ||
| // Row 2 | ||
| builder.keys().append_value("key1"); | ||
| builder.keys().append_value("key2"); | ||
| builder.values().append_value(1); | ||
| builder.values().append_value(3); | ||
| builder.append(true).unwrap(); | ||
| // Row 3 | ||
| builder.keys().append_value("key1"); | ||
| builder.keys().append_value("key3"); | ||
| builder.values().append_value(1); | ||
| builder.values().append_value(2); | ||
| builder.append(true).unwrap(); | ||
| // Row 4 | ||
| builder.keys().append_value("key1"); | ||
| builder.values().append_value(1); | ||
| builder.append(true).unwrap(); | ||
| // Row 5 | ||
| builder.keys().append_value("key1"); | ||
| builder.values().append_null(); | ||
| builder.append(true).unwrap(); | ||
| // Row 6 | ||
| builder.append(true).unwrap(); | ||
| // Row 7 | ||
| builder.keys().append_value("key1"); | ||
| builder.values().append_value(1); | ||
| builder.append(false).unwrap(); | ||
|
|
||
| let array = Arc::new(builder.finish()) as ArrayRef; | ||
|
|
||
| let random_state = RandomState::with_seeds(0, 0, 0, 0); | ||
| let mut hashes = vec![0; array.len()]; | ||
| create_hashes(&[array], &random_state, &mut hashes).unwrap(); | ||
| assert_eq!(hashes[0], hashes[1]); // same value | ||
| assert_ne!(hashes[0], hashes[2]); // different value | ||
| assert_ne!(hashes[0], hashes[3]); // different key | ||
| assert_ne!(hashes[0], hashes[4]); // missing an entry | ||
| assert_ne!(hashes[4], hashes[5]); // filled vs null value | ||
| assert_eq!(hashes[6], hashes[7]); // empty vs null map | ||
| } | ||
|
|
||
| #[test] | ||
| // Tests actual values of hashes, which are different if forcing collisions | ||
| #[cfg(not(feature = "force_hash_collisions"))] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1773,6 +1773,7 @@ impl ScalarValue { | |
| } | ||
| DataType::List(_) | ||
| | DataType::LargeList(_) | ||
| | DataType::Map(_, _) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This works at least for the test case; given it re-uses arrow::compute::concat, I'd hope it does the right thing overall
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is fine I think because |
||
| | DataType::Struct(_) | ||
| | DataType::Union(_, _) => { | ||
| let arrays = scalars.map(|s| s.to_array()).collect::<Result<Vec<_>>>()?; | ||
|
|
@@ -1841,7 +1842,6 @@ impl ScalarValue { | |
| | DataType::Time32(TimeUnit::Nanosecond) | ||
| | DataType::Time64(TimeUnit::Second) | ||
| | DataType::Time64(TimeUnit::Millisecond) | ||
| | DataType::Map(_, _) | ||
| | DataType::RunEndEncoded(_, _) | ||
| | DataType::ListView(_) | ||
| | DataType::LargeListView(_) => { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -302,3 +302,11 @@ SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), a | |
| {POST: 41, HEAD: 33, PATCH: 30} | ||
| {POST: 41, HEAD: 33, PATCH: 30} | ||
| {POST: 41, HEAD: 33, PATCH: 30} | ||
|
|
||
|
|
||
| query ? | ||
| VALUES (MAP(['a'], [1])), (MAP(['b'], [2])), (MAP(['c', 'a'], [3, 1])) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without the changes in |
||
| ---- | ||
| {a: 1} | ||
| {b: 2} | ||
| {c: 3, a: 1} | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -15,9 +15,9 @@ | |||||
| // specific language governing permissions and limitations | ||||||
| // under the License. | ||||||
|
|
||||||
| use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; | ||||||
| use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer}; | ||||||
| use async_recursion::async_recursion; | ||||||
| use datafusion::arrow::array::GenericListArray; | ||||||
| use datafusion::arrow::array::{GenericListArray, MapArray}; | ||||||
| use datafusion::arrow::datatypes::{ | ||||||
| DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, | ||||||
| }; | ||||||
|
|
@@ -51,6 +51,7 @@ use crate::variation_const::{ | |||||
| INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, | ||||||
| INTERVAL_YEAR_MONTH_TYPE_REF, | ||||||
| }; | ||||||
| use datafusion::arrow::array::{new_empty_array, AsArray}; | ||||||
| use datafusion::common::scalar::ScalarStructBuilder; | ||||||
| use datafusion::logical_expr::expr::InList; | ||||||
| use datafusion::logical_expr::{ | ||||||
|
|
@@ -1449,21 +1450,14 @@ fn from_substrait_type( | |||||
| from_substrait_type(value_type, extensions, dfs_names, name_idx)?, | ||||||
| true, | ||||||
| )); | ||||||
| match map.type_variation_reference { | ||||||
| DEFAULT_CONTAINER_TYPE_VARIATION_REF => { | ||||||
| Ok(DataType::Map( | ||||||
| Arc::new(Field::new_struct( | ||||||
| "entries", | ||||||
| [key_field, value_field], | ||||||
| false, // The inner map field is always non-nullable (Arrow #1697), | ||||||
| )), | ||||||
| false, | ||||||
| )) | ||||||
| } | ||||||
| v => not_impl_err!( | ||||||
| "Unsupported Substrait type variation {v} of type {s_kind:?}" | ||||||
| )?, | ||||||
| } | ||||||
| Ok(DataType::Map( | ||||||
| Arc::new(Field::new_struct( | ||||||
| "entries", | ||||||
| [key_field, value_field], | ||||||
| false, // The inner map field is always non-nullable (Arrow #1697), | ||||||
| )), | ||||||
| false, // whether keys are sorted | ||||||
| )) | ||||||
| } | ||||||
| r#type::Kind::Decimal(d) => match d.type_variation_reference { | ||||||
| DECIMAL_128_TYPE_VARIATION_REF => { | ||||||
|
|
@@ -1743,11 +1737,23 @@ fn from_substrait_literal( | |||||
| ) | ||||||
| } | ||||||
| Some(LiteralType::List(l)) => { | ||||||
| // Each element should start the name index from the same value, then we increase it | ||||||
| // once at the end | ||||||
| let mut element_name_idx = *name_idx; | ||||||
| let elements = l | ||||||
| .values | ||||||
| .iter() | ||||||
| .map(|el| from_substrait_literal(el, extensions, dfs_names, name_idx)) | ||||||
| .map(|el| { | ||||||
| element_name_idx = *name_idx; | ||||||
| from_substrait_literal( | ||||||
| el, | ||||||
| extensions, | ||||||
| dfs_names, | ||||||
| &mut element_name_idx, | ||||||
| ) | ||||||
| }) | ||||||
| .collect::<Result<Vec<_>>>()?; | ||||||
| *name_idx = element_name_idx; | ||||||
| if elements.is_empty() { | ||||||
| return substrait_err!( | ||||||
| "Empty list must be encoded as EmptyList literal type, not List" | ||||||
|
|
@@ -1785,6 +1791,84 @@ fn from_substrait_literal( | |||||
| } | ||||||
| } | ||||||
| } | ||||||
| Some(LiteralType::Map(m)) => { | ||||||
| // Each entry should start the name index from the same value, then we increase it | ||||||
| // once at the end | ||||||
| let mut entry_name_idx = *name_idx; | ||||||
| let entries = m | ||||||
| .key_values | ||||||
| .iter() | ||||||
| .map(|kv| { | ||||||
| entry_name_idx = *name_idx; | ||||||
| let key_sv = from_substrait_literal( | ||||||
| kv.key.as_ref().unwrap(), | ||||||
| extensions, | ||||||
| dfs_names, | ||||||
| &mut entry_name_idx, | ||||||
| )?; | ||||||
| let value_sv = from_substrait_literal( | ||||||
| kv.value.as_ref().unwrap(), | ||||||
| extensions, | ||||||
| dfs_names, | ||||||
| &mut entry_name_idx, | ||||||
| )?; | ||||||
| ScalarStructBuilder::new() | ||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was the most high-level way of creating the map I could think of, lmk if you have better ideas!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also/alternatively, should I add this into ScalarValue? they could sit next to ScalarValue::new_list etc
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It looks good to me. If you intend to be more efficient, I suggest referring to how
You need to partition the key and value pairs into two arrays, and build the
However, I think it just some improvements. We don't need to do that in this PR.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmh, yeah, I think I'll leave it as-is for now. As this code is relevant only for local relations (i.e. data encoded in the substrait message itself), I don't expect it to be very performance-sensitive; the data scale should hopefully always be small... |
||||||
| .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) | ||||||
| .with_scalar( | ||||||
| Field::new("value", value_sv.data_type(), true), | ||||||
| value_sv, | ||||||
| ) | ||||||
| .build() | ||||||
| }) | ||||||
| .collect::<Result<Vec<_>>>()?; | ||||||
| *name_idx = entry_name_idx; | ||||||
|
|
||||||
| if entries.is_empty() { | ||||||
| return substrait_err!( | ||||||
| "Empty map must be encoded as EmptyMap literal type, not Map" | ||||||
| ); | ||||||
| } | ||||||
|
|
||||||
| ScalarValue::Map(Arc::new(MapArray::new( | ||||||
| Arc::new(Field::new("entries", entries[0].data_type(), false)), | ||||||
| OffsetBuffer::new(vec![0, entries.len() as i32].into()), | ||||||
| ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), | ||||||
| None, | ||||||
| false, | ||||||
| ))) | ||||||
| } | ||||||
| Some(LiteralType::EmptyMap(m)) => { | ||||||
| let key = match &m.key { | ||||||
| Some(k) => Ok(k), | ||||||
| _ => plan_err!("Missing key type for empty map"), | ||||||
| }?; | ||||||
| let value = match &m.value { | ||||||
| Some(v) => Ok(v), | ||||||
| _ => plan_err!("Missing value type for empty map"), | ||||||
| }?; | ||||||
| let key_type = from_substrait_type(key, extensions, dfs_names, name_idx)?; | ||||||
| let value_type = from_substrait_type(value, extensions, dfs_names, name_idx)?; | ||||||
|
|
||||||
| // new_empty_array on a MapType creates a too empty array | ||||||
| // We want it to contain an empty struct array to align with an empty MapBuilder one | ||||||
| let entries = Field::new_struct( | ||||||
| "entries", | ||||||
| vec![ | ||||||
| Field::new("key", key_type, false), | ||||||
| Field::new("value", value_type, true), | ||||||
| ], | ||||||
| false, | ||||||
| ); | ||||||
| let struct_array = | ||||||
| new_empty_array(entries.data_type()).as_struct().to_owned(); | ||||||
| ScalarValue::Map(Arc::new(MapArray::new( | ||||||
| Arc::new(entries), | ||||||
| OffsetBuffer::new(vec![0, 0].into()), | ||||||
| struct_array, | ||||||
| None, | ||||||
| false, | ||||||
| ))) | ||||||
| } | ||||||
| Some(LiteralType::Struct(s)) => { | ||||||
| let mut builder = ScalarStructBuilder::new(); | ||||||
| for (i, field) in s.fields.iter().enumerate() { | ||||||
|
|
@@ -2013,6 +2097,29 @@ fn from_substrait_null( | |||||
| ), | ||||||
| } | ||||||
| } | ||||||
| r#type::Kind::Map(map) => { | ||||||
| let key_type = map.key.as_ref().ok_or_else(|| { | ||||||
| substrait_datafusion_err!("Map type must have key type") | ||||||
| })?; | ||||||
| let value_type = map.value.as_ref().ok_or_else(|| { | ||||||
| substrait_datafusion_err!("Map type must have value type") | ||||||
| })?; | ||||||
|
|
||||||
| let key_type = | ||||||
| from_substrait_type(key_type, extensions, dfs_names, name_idx)?; | ||||||
| let value_type = | ||||||
| from_substrait_type(value_type, extensions, dfs_names, name_idx)?; | ||||||
| let entries_field = Arc::new(Field::new_struct( | ||||||
| "entries", | ||||||
| vec![ | ||||||
| Field::new("key", key_type, false), | ||||||
| Field::new("value", value_type, true), | ||||||
| ], | ||||||
| false, | ||||||
| )); | ||||||
|
|
||||||
| DataType::Map(entries_field, false /* keys sorted */).try_into() | ||||||
| } | ||||||
| r#type::Kind::Struct(s) => { | ||||||
| let fields = | ||||||
| from_substrait_struct_type(s, extensions, dfs_names, name_idx)?; | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I adapted this logic from combining List and Struct hashing; the result seemed to make sense to me, but I'm not 100% confident in it