diff --git a/src/common.rs b/src/common.rs index 990ffe2..9cbee92 100644 --- a/src/common.rs +++ b/src/common.rs @@ -77,11 +77,8 @@ pub fn invoke> + 'static, I>( Some(ColumnarValue::Array(a)) => { if args.len() > 2 { // TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23 - return exec_err!( - "More than 1 path element is not supported when querying JSON using an array." - ); - } - if let Some(str_path_array) = a.as_any().downcast_ref::() { + exec_err!("More than 1 path element is not supported when querying JSON using an array.") + } else if let Some(str_path_array) = a.as_any().downcast_ref::() { let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); zip_apply(json_array, paths, jiter_find, true) } else if let Some(str_path_array) = a.as_any().downcast_ref::() { @@ -94,7 +91,7 @@ pub fn invoke> + 'static, I>( let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into)); zip_apply(json_array, paths, jiter_find, false) } else { - return exec_err!("unexpected second argument type, expected string or int array"); + exec_err!("unexpected second argument type, expected string or int array") } } Some(ColumnarValue::Scalar(_)) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find), diff --git a/src/common_union.rs b/src/common_union.rs index 7cc59db..ae4433b 100644 --- a/src/common_union.rs +++ b/src/common_union.rs @@ -1,6 +1,6 @@ use std::sync::{Arc, OnceLock}; -use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, UnionArray}; +use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray}; use arrow::buffer::Buffer; use arrow_schema::{DataType, Field, UnionFields, UnionMode}; use datafusion_common::ScalarValue; @@ -42,7 +42,6 @@ pub(crate) fn json_from_union_scalar<'a>( #[derive(Debug)] pub(crate) struct JsonUnion { - nulls: Vec>, bools: Vec>, ints: Vec>, floats: Vec>, @@ -51,22 +50,21 @@ pub(crate) struct JsonUnion { objects: Vec>, type_ids: Vec, index: usize, - capacity: usize, + length: usize, } impl JsonUnion { - fn new(capacity: usize) -> Self { + fn new(length: usize) -> Self { Self { - nulls: vec![None; capacity], - bools: vec![None; capacity], - ints: vec![None; capacity], - floats: vec![None; capacity], - strings: vec![None; capacity], - arrays: vec![None; capacity], - objects: vec![None; capacity], - type_ids: vec![0; capacity], + bools: vec![None; length], + ints: vec![None; length], + floats: vec![None; length], + strings: vec![None; length], + arrays: vec![None; length], + objects: vec![None; length], + type_ids: vec![0; length], index: 0, - capacity, + length, } } @@ -77,7 +75,7 @@ impl JsonUnion { fn push(&mut self, field: JsonUnionField) { self.type_ids[self.index] = field.type_id(); match field { - JsonUnionField::JsonNull => self.nulls[self.index] = Some(true), + JsonUnionField::JsonNull => (), JsonUnionField::Bool(value) => self.bools[self.index] = Some(value), JsonUnionField::Int(value) => self.ints[self.index] = Some(value), JsonUnionField::Float(value) => self.floats[self.index] = Some(value), @@ -86,13 +84,12 @@ impl JsonUnion { JsonUnionField::Object(value) => self.objects[self.index] = Some(value), } self.index += 1; - debug_assert!(self.index <= self.capacity); + debug_assert!(self.index <= self.length); } fn push_none(&mut self) { - self.type_ids[self.index] = TYPE_ID_NULL; self.index += 1; - debug_assert!(self.index <= self.capacity); + debug_assert!(self.index <= self.length); } } @@ -119,7 +116,7 @@ impl TryFrom for UnionArray { fn try_from(value: JsonUnion) -> Result { let children: Vec> = vec![ - Arc::new(BooleanArray::from(value.nulls)), + Arc::new(NullArray::new(value.length)), Arc::new(BooleanArray::from(value.bools)), Arc::new(Int64Array::from(value.ints)), Arc::new(Float64Array::from(value.floats)), @@ -155,7 +152,7 @@ fn union_fields() -> UnionFields { FIELDS .get_or_init(|| { UnionFields::from_iter([ - (TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Boolean, true))), + (TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Null, true))), (TYPE_ID_BOOL, Arc::new(Field::new("bool", DataType::Boolean, false))), (TYPE_ID_INT, Arc::new(Field::new("int", DataType::Int64, false))), (TYPE_ID_FLOAT, Arc::new(Field::new("float", DataType::Float64, false))), diff --git a/tests/main.rs b/tests/main.rs index 26a79fe..efa916e 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -68,7 +68,7 @@ async fn test_json_get_union() { "| object_foo | {str=abc} |", "| object_foo_array | {array=[1]} |", "| object_foo_obj | {object={}} |", - "| object_foo_null | {null=true} |", + "| object_foo_null | {null=} |", "| object_bar | {null=} |", "| list_foo | {null=} |", "| invalid_json | {null=} |", @@ -675,7 +675,7 @@ async fn test_json_get_union_array_nested() { "+-------------+", "| {array=[0]} |", "| {null=} |", - "| {null=true} |", + "| {null=} |", "+-------------+", ]; @@ -725,7 +725,7 @@ async fn test_arrow() { "| object_foo | {str=abc} |", "| object_foo_array | {array=[1]} |", "| object_foo_obj | {object={}} |", - "| object_foo_null | {null=true} |", + "| object_foo_null | {null=} |", "| object_bar | {null=} |", "| list_foo | {null=} |", "| invalid_json | {null=} |", @@ -903,7 +903,7 @@ async fn test_arrow_nested_columns() { "+-------------+", "| {array=[0]} |", "| {null=} |", - "| {null=true} |", + "| {null=} |", "+-------------+", ]; @@ -990,3 +990,112 @@ async fn test_question_filter() { ]; assert_batches_eq!(expected, &batches); } + +#[tokio::test] +async fn test_json_get_union_is_null() { + let batches = run_query("select name, json_get(json_data, 'foo') is null from test") + .await + .unwrap(); + + let expected = [ + "+------------------+----------------------------------------------+", + "| name | json_get(test.json_data,Utf8(\"foo\")) IS NULL |", + "+------------------+----------------------------------------------+", + "| object_foo | false |", + "| object_foo_array | false |", + "| object_foo_obj | false |", + "| object_foo_null | true |", + "| object_bar | true |", + "| list_foo | true |", + "| invalid_json | true |", + "+------------------+----------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_get_union_is_not_null() { + let batches = run_query("select name, json_get(json_data, 'foo') is not null from test") + .await + .unwrap(); + + let expected = [ + "+------------------+--------------------------------------------------+", + "| name | json_get(test.json_data,Utf8(\"foo\")) IS NOT NULL |", + "+------------------+--------------------------------------------------+", + "| object_foo | true |", + "| object_foo_array | true |", + "| object_foo_obj | true |", + "| object_foo_null | false |", + "| object_bar | false |", + "| list_foo | false |", + "| invalid_json | false |", + "+------------------+--------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_arrow_union_is_null() { + let batches = run_query("select name, (json_data->'foo') is null from test") + .await + .unwrap(); + + let expected = [ + "+------------------+----------------------------------+", + "| name | json_data -> Utf8(\"foo\") IS NULL |", + "+------------------+----------------------------------+", + "| object_foo | false |", + "| object_foo_array | false |", + "| object_foo_obj | false |", + "| object_foo_null | true |", + "| object_bar | true |", + "| list_foo | true |", + "| invalid_json | true |", + "+------------------+----------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_arrow_union_is_not_null() { + let batches = run_query("select name, (json_data->'foo') is not null from test") + .await + .unwrap(); + + let expected = [ + "+------------------+--------------------------------------+", + "| name | json_data -> Utf8(\"foo\") IS NOT NULL |", + "+------------------+--------------------------------------+", + "| object_foo | true |", + "| object_foo_array | true |", + "| object_foo_obj | true |", + "| object_foo_null | false |", + "| object_bar | false |", + "| list_foo | false |", + "| invalid_json | false |", + "+------------------+--------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_arrow_scalar_union_is_null() { + let batches = run_query( + r#" + select ('{"x": 1}'->'foo') is null as not_contains, + ('{"foo": 1}'->'foo') is null as contains_num, + ('{"foo": null}'->'foo') is null as contains_null"#, + ) + .await + .unwrap(); + + let expected = [ + "+--------------+--------------+---------------+", + "| not_contains | contains_num | contains_null |", + "+--------------+--------------+---------------+", + "| true | false | true |", + "+--------------+--------------+---------------+", + ]; + assert_batches_eq!(expected, &batches); +}