diff --git a/Cargo.toml b/Cargo.toml
index 8438382..6d08b71 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,6 +4,7 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+arrow = "51.0.0"
 arrow-schema = "51.0.0"
 datafusion-common = "37.0.0"
 datafusion-expr = "37.0.0"
@@ -13,7 +14,6 @@ log = "0.4.21"
 datafusion-execution = "37.0.0"
 
 [dev-dependencies]
-arrow = "51.0.0"
 datafusion = "37.0.0"
 tokio = { version = "1.37.0", features = ["full"] }
 
@@ -25,3 +25,4 @@ print_stdout = "warn"
 # certain lints which we don't want to enforce (for now)
 pedantic = { level = "warn", priority = -1 }
 missing_errors_doc = "allow"
+cast_possible_truncation = "allow"
diff --git a/README.md b/README.md
index d6cce16..1b1b21c 100644
--- a/README.md
+++ b/README.md
@@ -1,18 +1,25 @@
 # datafusion-functions-json
 
-methods to implement:
+## Done
+* [x] `json_get(json: str, *keys: str | int) -> JsonUnion` - Get a value from a JSON object by its "path"
+* [x] `json_get_str(json: str, *keys: str | int) -> str` - Get a string value from a JSON object by its "path"
+* [x] `json_get_int(json: str, *keys: str | int) -> int` - Get an integer value from a JSON object by its "path"
+* [x] `json_get_float(json: str, *keys: str | int) -> float` - Get a float value from a JSON object by its "path"
+* [x] `json_get_bool(json: str, *keys: str | int) -> bool` - Get a boolean value from a JSON object by its "path"
+* [x] `json_get_json(json: str, *keys: str | int) -> str` - Get any value from a JSON object by its "path", represented as a string
 * [x] `json_obj_contains(json: str, key: str) -> bool` - true if a JSON object has a specific key
-* [ ] `json_obj_contains_all(json: str, keys: list[str]) -> bool` - true if a JSON object has all of a list of keys
-* [ ] `json_obj_contains_any(json: str, keys: list[str]) -> bool` - true if a JSON object has all of a list of keys
+
+## TODO
+
 * [ ] `json_obj_keys(json: str) -> list[str]` - get the keys of a JSON object
+* [ ] `json_length(json: str) -> int` - get the length of a JSON object or array
 * [ ] `json_obj_values(json: str) -> list[Any]` - get the values of a JSON object
+* [ ] `json_obj_contains_all(json: str, keys: list[str]) -> bool` - true if a JSON object has all of the keys
+* [ ] `json_obj_contains_any(json: str, keys: list[str]) -> bool` - true if a JSON object has any of the keys
 * [ ] `json_is_obj(json: str) -> bool` - true if the JSON is an object
 * [ ] `json_array_contains(json: str, key: Any) -> bool` - true if a JSON array has a specific value
-* [ ] `json_array_items(json: str) -> list[Any]` - get the items of a JSON array
+* [ ] `json_array_items_str(json: str) -> list[Any]` - get the items of a JSON array
 * [ ] `json_is_array(json: str) -> bool` - true if the JSON is an array
-* [ ] `json_get(json: str, key: str | int) -> Any` - get the value of a key in a JSON object or array
-* [ ] `json_get_path(json: str, key: list[str | int]) -> Any` - is this possible?
-* [ ] `json_length(json: str) -> int` - get the length of a JSON object or array
 * [ ] `json_valid(json: str) -> bool` - true if the JSON is valid
 * [ ] `json_cast(json: str) -> Any` - cast the JSON to a native type???
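Reviewer note: to make the "Done" list above concrete, here is a small usage sketch of the registered functions in SQL, mirroring queries exercised by the tests later in this diff (the `test` table and its `json_data` column come from the test fixtures; assume `register_all` has been called on the session):

```sql
-- union-typed result, displayed e.g. as {str=abc}
select json_get(json_data, 'foo') from test;
-- typed variants return plain Utf8 / Int64 / Float64 / Boolean
select json_get_str(json_data, 'foo') from test;
-- paths may mix object keys and array indexes
select json_get_str('["a", "b", "c"]', 1);
-- casting json_get is rewritten to the typed function (see src/rewrite.rs below)
select json_get('{"foo": 42}', 'foo')::int;
-- key membership on the top-level object
select json_obj_contains(json_data, 'foo') from test;
```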
diff --git a/src/common_get.rs b/src/common_get.rs
new file mode 100644
index 0000000..d4da4fb
--- /dev/null
+++ b/src/common_get.rs
@@ -0,0 +1,199 @@
+use std::str::Utf8Error;
+
+use arrow::array::{as_string_array, Array, ArrayRef, Int64Array, StringArray};
+use arrow_schema::DataType;
+use datafusion_common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
+use datafusion_expr::ColumnarValue;
+use jiter::{Jiter, JiterError, Peek};
+
+pub fn check_args(args: &[DataType], fn_name: &str) -> DataFusionResult<()> {
+    let first = match args.get(0) {
+        Some(arg) => arg,
+        None => return plan_err!("The `{fn_name}` function requires one or more arguments."),
+    };
+    if !matches!(first, DataType::Utf8) {
+        return plan_err!("Unexpected argument type to `{fn_name}` at position 1, expected a string.");
+    }
+    args[1..].iter().enumerate().try_for_each(|(index, arg)| match arg {
+        DataType::Utf8 | DataType::UInt64 | DataType::Int64 => Ok(()),
+        _ => plan_err!(
+            "Unexpected argument type to `{fn_name}` at position {}, expected string or int.",
+            index + 2
+        ),
+    })
+}
+
+#[derive(Debug)]
+pub enum JsonPath<'s> {
+    Key(&'s str),
+    Index(usize),
+    None,
+}
+
+impl From<u64> for JsonPath<'_> {
+    fn from(index: u64) -> Self {
+        JsonPath::Index(index as usize)
+    }
+}
+
+impl From<i64> for JsonPath<'_> {
+    fn from(index: i64) -> Self {
+        match usize::try_from(index) {
+            Ok(i) => Self::Index(i),
+            Err(_) => Self::None,
+        }
+    }
+}
+
+impl<'s> JsonPath<'s> {
+    pub fn extract_path(args: &'s [ColumnarValue]) -> Vec<Self> {
+        args[1..]
+            .iter()
+            .map(|arg| match arg {
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Self::Key(s),
+                ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => (*i).into(),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => (*i).into(),
+                _ => Self::None,
+            })
+            .collect()
+    }
+}
+
+pub fn get_invoke<C: FromIterator<Option<I>> + 'static, I>(
+    args: &[ColumnarValue],
+    jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
+    to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
+    to_scalar: impl Fn(Option<I>) -> ScalarValue,
+) -> DataFusionResult<ColumnarValue> {
+    match &args[0] {
+        ColumnarValue::Array(json_array) => {
+            let result_collect = match &args[1] {
+                ColumnarValue::Array(a) => {
+                    if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
+                        let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
+                        zip_apply(json_array, paths, jiter_find)
+                    } else if let Some(int_path_array) = a.as_any().downcast_ref::<Int64Array>() {
+                        let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
+                        zip_apply(json_array, paths, jiter_find)
+                    } else {
+                        return exec_err!("unexpected second argument type, expected string or int array");
+                    }
+                }
+                ColumnarValue::Scalar(_) => {
+                    let path = JsonPath::extract_path(args);
+                    as_string_array(json_array)
+                        .iter()
+                        .map(|opt_json| jiter_find(opt_json, &path).ok())
+                        .collect::<C>()
+                }
+            };
+            to_array(result_collect).map(ColumnarValue::from)
+        }
+        ColumnarValue::Scalar(ScalarValue::Utf8(s)) => {
+            let path = JsonPath::extract_path(args);
+            let v = jiter_find(s.as_ref().map(String::as_str), &path).ok();
+            Ok(ColumnarValue::Scalar(to_scalar(v)))
+        }
+        ColumnarValue::Scalar(_) => {
+            exec_err!("unexpected first argument type, expected string")
+        }
+    }
+}
+
+fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
+    json_array: &ArrayRef,
+    paths: P,
+    jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
+) -> C {
+    as_string_array(json_array)
+        .iter()
+        .zip(paths)
+        .map(|(opt_json, opt_path)| {
+            if let Some(path) = opt_path {
+                jiter_find(opt_json, &[path]).ok()
+            } else {
+                None
+            }
+        })
+        .collect::<C>()
+}
+
+pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> {
+    if let Some(json_str) = opt_json {
+        let mut jiter = Jiter::new(json_str.as_bytes(), false);
+        if let Ok(peek) = jiter.peek() {
+            if let Ok(peek_found) = jiter_json_find_step(&mut jiter, peek, path) {
+                return Some((jiter, peek_found));
+            }
+        }
+    }
+    None
+}
+
+macro_rules! get_err {
+    () => {
+        Err(GetError)
+    };
+}
+pub(crate) use get_err;
+
+fn jiter_json_find_step(jiter: &mut Jiter, peek: Peek, path: &[JsonPath]) -> Result<Peek, GetError> {
+    let (first, rest) = match path.split_first() {
+        Some(first_rest) => first_rest,
+        None => return Ok(peek),
+    };
+    let next_peek = match peek {
+        Peek::Array => match first {
+            JsonPath::Index(index) => jiter_array_get(jiter, *index),
+            _ => get_err!(),
+        },
+        Peek::Object => match first {
+            JsonPath::Key(key) => jiter_object_get(jiter, key),
+            _ => get_err!(),
+        },
+        _ => get_err!(),
+    }?;
+    jiter_json_find_step(jiter, next_peek, rest)
+}
+
+fn jiter_array_get(jiter: &mut Jiter, find_key: usize) -> Result<Peek, GetError> {
+    let mut peek_opt = jiter.known_array()?;
+
+    let mut index: usize = 0;
+    while let Some(peek) = peek_opt {
+        if index == find_key {
+            return Ok(peek);
+        }
+        jiter.next_skip()?;
+        index += 1;
+        peek_opt = jiter.array_step()?;
+    }
+    get_err!()
+}
+
+fn jiter_object_get(jiter: &mut Jiter, find_key: &str) -> Result<Peek, GetError> {
+    let mut opt_key = jiter.known_object()?;
+
+    while let Some(key) = opt_key {
+        if key == find_key {
+            let value_peek = jiter.peek()?;
+            return Ok(value_peek);
+        }
+        jiter.next_skip()?;
+        opt_key = jiter.next_key()?;
+    }
+    get_err!()
+}
+
+pub struct GetError;
+
+impl From<JiterError> for GetError {
+    fn from(_: JiterError) -> Self {
+        GetError
+    }
+}
+
+impl From<Utf8Error> for GetError {
+    fn from(_: Utf8Error) -> Self {
+        GetError
+    }
+}
diff --git a/src/common_macros.rs b/src/common_macros.rs
new file mode 100644
index 0000000..f3aa3b1
--- /dev/null
+++ b/src/common_macros.rs
@@ -0,0 +1,49 @@
+/// Creates the external API for a JSON UDF. Specifically, creates
+///
+/// a singleton `ScalarUDF` for the `$udf_impl` implementation, a function named `$expr_fn_name _udf`
+/// which returns that singleton, and an `expr_fn`-style function named `$expr_fn_name`.
+///
+/// This is used to ensure creating the list of `ScalarUDF` only happens once.
+///
+/// # Arguments
+/// * `udf_impl`: name of the [`ScalarUDFImpl`]
+/// * `expr_fn_name`: name of the `expr_fn` function to be created
+/// * `arg`: 0 or more named arguments for the function
+/// * `doc`: documentation string for the function
+///
+/// Copied mostly from `datafusion/functions-array/src/macros.rs`.
+///
+/// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl
+macro_rules! make_udf_function {
+    ($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr) => {
+        paste::paste! {
+            #[doc = $doc]
+            #[must_use] pub fn $expr_fn_name($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr {
+                datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf(
+                    [< $expr_fn_name _udf >](),
+                    vec![$($arg),*],
+                ))
+            }
+
+            /// Singleton instance of [`$udf_impl`]; ensures the UDF is only created once,
+            /// named for example `STATIC_JSON_OBJ_CONTAINS`
+            static [< STATIC_ $expr_fn_name:upper >]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::ScalarUDF>> =
+                std::sync::OnceLock::new();
+
+            /// ScalarFunction that returns a [`ScalarUDF`] for [`$udf_impl`]
+            ///
+            /// [`ScalarUDF`]: datafusion_expr::ScalarUDF
+            pub fn [< $expr_fn_name _udf >]() -> std::sync::Arc<datafusion_expr::ScalarUDF> {
+                [< STATIC_ $expr_fn_name:upper >]
+                    .get_or_init(|| {
+                        std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl(
+                            <$udf_impl>::default(),
+                        ))
+                    })
+                    .clone()
+            }
+        }
+    };
+}
+
+pub(crate) use make_udf_function;
diff --git a/src/common_union.rs b/src/common_union.rs
new file mode 100644
index 0000000..05ecb6e
--- /dev/null
+++ b/src/common_union.rs
@@ -0,0 +1,160 @@
+use std::sync::Arc;
+
+use arrow::array::{Array, BooleanArray, Float64Array, Int64Array, StringArray, UnionArray};
+use arrow::buffer::Buffer;
+use arrow_schema::{DataType, Field, UnionFields, UnionMode};
+use datafusion_common::ScalarValue;
+
+#[derive(Debug)]
+pub(crate) struct JsonUnion {
+    nulls: Vec<Option<bool>>,
+    bools: Vec<Option<bool>>,
+    ints: Vec<Option<i64>>,
+    floats: Vec<Option<f64>>,
+    strings: Vec<Option<String>>,
+    arrays: Vec<Option<String>>,
+    objects: Vec<Option<String>>,
+    type_ids: Vec<i8>,
+    index: usize,
+    capacity: usize,
+}
+
+impl JsonUnion {
+    fn new(capacity: usize) -> Self {
+        Self {
+            nulls: vec![None; capacity],
+            bools: vec![None; capacity],
+            ints: vec![None; capacity],
+            floats: vec![None; capacity],
+            strings: vec![None; capacity],
+            arrays: vec![None; capacity],
+            objects: vec![None; capacity],
+            type_ids: vec![0; capacity],
+            index: 0,
+            capacity,
+        }
+    }
+
+    pub fn data_type() -> DataType {
+        DataType::Union(
+            UnionFields::new(TYPE_IDS.to_vec(), union_fields().to_vec()),
+            UnionMode::Sparse,
+        )
+    }
+
+    fn push(&mut self, field: JsonUnionField) {
+        self.type_ids[self.index] = field.type_id();
+        match field {
+            JsonUnionField::JsonNull => self.nulls[self.index] = Some(true),
+            JsonUnionField::Bool(value) => self.bools[self.index] = Some(value),
+            JsonUnionField::Int(value) => self.ints[self.index] = Some(value),
+            JsonUnionField::Float(value) => self.floats[self.index] = Some(value),
+            JsonUnionField::Str(value) => self.strings[self.index] = Some(value),
+            JsonUnionField::Array(value) => self.arrays[self.index] = Some(value),
+            JsonUnionField::Object(value) => self.objects[self.index] = Some(value),
+        }
+        self.index += 1;
+        debug_assert!(self.index <= self.capacity);
+    }
+
+    fn push_none(&mut self) {
+        self.type_ids[self.index] = TYPE_IDS[0];
+        self.index += 1;
+        debug_assert!(self.index <= self.capacity);
+    }
+}
+
+/// So we can do `collect::<JsonUnion>()`
+impl FromIterator<Option<JsonUnionField>> for JsonUnion {
+    fn from_iter<I: IntoIterator<Item = Option<JsonUnionField>>>(iter: I) -> Self {
+        let inner = iter.into_iter();
+        let (lower, upper) = inner.size_hint();
+        let mut union = Self::new(upper.unwrap_or(lower));
+
+        for opt_field in inner {
+            if let Some(union_field) = opt_field {
+                union.push(union_field);
+            } else {
+                union.push_none();
+            }
+        }
+        union
+    }
+}
+
+impl TryFrom<JsonUnion> for UnionArray {
+    type Error = arrow::error::ArrowError;
+
+    fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
+        let [f0, f1, f2, f3, f4, f5, f6] = union_fields();
+        let children: Vec<(Field, Arc<dyn Array>)> = vec![
+            (f0, Arc::new(BooleanArray::from(value.nulls))),
+            (f1, Arc::new(BooleanArray::from(value.bools))),
+            (f2, Arc::new(Int64Array::from(value.ints))),
+            (f3, Arc::new(Float64Array::from(value.floats))),
+            (f4, Arc::new(StringArray::from(value.strings))),
+            (f5, Arc::new(StringArray::from(value.arrays))),
+            (f6, Arc::new(StringArray::from(value.objects))),
+        ];
+        UnionArray::try_new(TYPE_IDS, Buffer::from_slice_ref(&value.type_ids), None, children)
+    }
+}
+
+#[derive(Debug)]
+pub(crate) enum JsonUnionField {
+    JsonNull,
+    Bool(bool),
+    Int(i64),
+    Float(f64),
+    Str(String),
+    Array(String),
+    Object(String),
+}
+
+const TYPE_IDS: &[i8] = &[0, 1, 2, 3, 4, 5, 6];
+
+fn union_fields() -> [Field; 7] {
+    [
+        Field::new("null", DataType::Boolean, true),
+        Field::new("bool", DataType::Boolean, false),
+        Field::new("int", DataType::Int64, false),
+        Field::new("float", DataType::Float64, false),
+        Field::new("str", DataType::Utf8, false),
+        Field::new("array", DataType::Utf8, false),
+        Field::new("object", DataType::Utf8, false),
+    ]
+}
+
+impl JsonUnionField {
+    fn type_id(&self) -> i8 {
+        match self {
+            Self::JsonNull => 0,
+            Self::Bool(_) => 1,
+            Self::Int(_) => 2,
+            Self::Float(_) => 3,
+            Self::Str(_) => 4,
+            Self::Array(_) => 5,
+            Self::Object(_) => 6,
+        }
+    }
+
+    pub fn scalar_value(f: Option<Self>) -> ScalarValue {
+        ScalarValue::Union(
+            f.map(|f| (f.type_id(), Box::new(f.into()))),
+            UnionFields::new(TYPE_IDS.to_vec(), union_fields().to_vec()),
+            UnionMode::Sparse,
+        )
+    }
+}
+
+impl From<JsonUnionField> for ScalarValue {
+    fn from(value: JsonUnionField) -> Self {
+        match value {
+            JsonUnionField::JsonNull => Self::Null,
+            JsonUnionField::Bool(b) => Self::Boolean(Some(b)),
+            JsonUnionField::Int(i) => Self::Int64(Some(i)),
+            JsonUnionField::Float(f) => Self::Float64(Some(f)),
+            JsonUnionField::Str(s) | JsonUnionField::Array(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)),
+        }
+    }
+}
diff --git a/src/json_get.rs b/src/json_get.rs
new file mode 100644
index 0000000..4a46f80
--- /dev/null
+++ b/src/json_get.rs
@@ -0,0 +1,111 @@
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::UnionArray;
+use arrow_schema::DataType;
+use datafusion_common::arrow::array::ArrayRef;
+use datafusion_common::Result as DataFusionResult;
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use jiter::{Jiter, NumberAny, NumberInt, Peek};
+
+use crate::common_get::{check_args, get_err, get_invoke, jiter_json_find, GetError, JsonPath};
+use crate::common_macros::make_udf_function;
+use crate::common_union::{JsonUnion, JsonUnionField};
+
+make_udf_function!(
+    JsonGet,
+    json_get,
+    json_data key, // arg name
+    r#"Get a value from a JSON object by its "path""#
+);
+
+// build_typed_get!(JsonGet, "json_get", Union, Float64Array, jiter_json_get_float);
+
+#[derive(Debug)]
+pub(super) struct JsonGet {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for JsonGet {
+    fn default() -> Self {
+        Self {
+            signature: Signature::variadic(vec![DataType::Utf8, DataType::UInt64], Volatility::Immutable),
+            aliases: vec!["json_get".to_string(), "json_get_union".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for JsonGet {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.aliases[0].as_str()
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        check_args(arg_types, self.name()).map(|()| JsonUnion::data_type())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+        let to_array = |c: JsonUnion| {
+            let array: UnionArray = c.try_into()?;
+            Ok(Arc::new(array) as ArrayRef)
+        };
+        get_invoke::<JsonUnion, JsonUnionField>(args, jiter_json_get_union, to_array, JsonUnionField::scalar_value)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result<JsonUnionField, GetError> {
+    if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
+        build_union(&mut jiter, peek)
+    } else {
+        get_err!()
+    }
+}
+
+fn build_union(jiter: &mut Jiter, peek: Peek) -> Result<JsonUnionField, GetError> {
+    match peek {
+        Peek::Null => {
+            jiter.known_null()?;
+            Ok(JsonUnionField::JsonNull)
+        }
+        Peek::True | Peek::False => {
+            let value = jiter.known_bool(peek)?;
+            Ok(JsonUnionField::Bool(value))
+        }
+        Peek::String => {
+            let value = jiter.known_str()?;
+            Ok(JsonUnionField::Str(value.to_owned()))
+        }
+        Peek::Array => {
+            let start = jiter.current_index();
+            jiter.known_skip(peek)?;
+            let array_slice = jiter.slice_to_current(start);
+            let array_string = std::str::from_utf8(array_slice)?;
+            Ok(JsonUnionField::Array(array_string.to_owned()))
+        }
+        Peek::Object => {
+            let start = jiter.current_index();
+            jiter.known_skip(peek)?;
+            let object_slice = jiter.slice_to_current(start);
+            let object_string = std::str::from_utf8(object_slice)?;
+            Ok(JsonUnionField::Object(object_string.to_owned()))
+        }
+        _ => match jiter.known_number(peek)? {
+            NumberAny::Int(NumberInt::Int(value)) => Ok(JsonUnionField::Int(value)),
+            NumberAny::Int(NumberInt::BigInt(_)) => todo!("BigInt not supported yet"),
+            NumberAny::Float(value) => Ok(JsonUnionField::Float(value)),
+        },
+    }
+}
diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs
new file mode 100644
index 0000000..093771f
--- /dev/null
+++ b/src/json_get_bool.rs
@@ -0,0 +1,75 @@
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, BooleanArray};
+use arrow_schema::DataType;
+use datafusion_common::{Result as DataFusionResult, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use jiter::Peek;
+
+use crate::common_get::{check_args, get_err, get_invoke, jiter_json_find, GetError, JsonPath};
+use crate::common_macros::make_udf_function;
+
+make_udf_function!(
+    JsonGetBool,
+    json_get_bool,
+    json_data path, // arg name
+    r#"Get a boolean value from a JSON object by its "path""#
+);
+
+#[derive(Debug)]
+pub(super) struct JsonGetBool {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for JsonGetBool {
+    fn default() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+            aliases: vec!["json_get_bool".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for JsonGetBool {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.aliases[0].as_str()
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        check_args(arg_types, self.name()).map(|()| DataType::Boolean)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+        get_invoke::<BooleanArray, bool>(
+            args,
+            jiter_json_get_bool,
+            |c| Ok(Arc::new(c) as ArrayRef),
+            ScalarValue::Boolean,
+        )
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+fn jiter_json_get_bool(json_data: Option<&str>, path: &[JsonPath]) -> Result<bool, GetError> {
+    if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) {
+        match peek {
+            Peek::True | Peek::False => Ok(jiter.known_bool(peek)?),
+            _ => get_err!(),
+        }
+    } else {
+        get_err!()
+    }
+}
diff --git a/src/json_get_float.rs b/src/json_get_float.rs
new file mode 100644
index 0000000..361c407
--- /dev/null
+++ b/src/json_get_float.rs
@@ -0,0 +1,87 @@
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, Float64Array};
+use arrow_schema::DataType;
+use datafusion_common::{Result as DataFusionResult, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use jiter::{NumberAny, Peek};
+
+use crate::common_get::{check_args, get_err, get_invoke, jiter_json_find, GetError, JsonPath};
+use crate::common_macros::make_udf_function;
+
+make_udf_function!(
+    JsonGetFloat,
+    json_get_float,
+    json_data path, // arg name
+    r#"Get a float value from a JSON object by its "path""#
+);
+
+#[derive(Debug)]
+pub(super) struct JsonGetFloat {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for JsonGetFloat {
+    fn default() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+            aliases: vec!["json_get_float".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for JsonGetFloat {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.aliases[0].as_str()
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        check_args(arg_types, self.name()).map(|()| DataType::Float64)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+        get_invoke::<Float64Array, f64>(
+            args,
+            jiter_json_get_float,
+            |c| Ok(Arc::new(c) as ArrayRef),
+            ScalarValue::Float64,
+        )
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath]) -> Result<f64, GetError> {
+    if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) {
+        match peek {
+            // numbers are represented by everything else in peek, hence doing it this way
+            Peek::Null
+            | Peek::True
+            | Peek::False
+            | Peek::Minus
+            | Peek::Infinity
+            | Peek::NaN
+            | Peek::String
+            | Peek::Array
+            | Peek::Object => get_err!(),
+            _ => match jiter.known_number(peek)? {
+                NumberAny::Float(f) => Ok(f),
+                NumberAny::Int(int) => Ok(int.into()),
+            },
+        }
+    } else {
+        get_err!()
+    }
+}
diff --git a/src/json_get_int.rs b/src/json_get_int.rs
new file mode 100644
index 0000000..1917160
--- /dev/null
+++ b/src/json_get_int.rs
@@ -0,0 +1,87 @@
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, Int64Array};
+use arrow_schema::DataType;
+use datafusion_common::{Result as DataFusionResult, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use jiter::{NumberInt, Peek};
+
+use crate::common_get::{check_args, get_err, get_invoke, jiter_json_find, GetError, JsonPath};
+use crate::common_macros::make_udf_function;
+
+make_udf_function!(
+    JsonGetInt,
+    json_get_int,
+    json_data path, // arg name
+    r#"Get an integer value from a JSON object by its "path""#
+);
+
+#[derive(Debug)]
+pub(super) struct JsonGetInt {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for JsonGetInt {
+    fn default() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+            aliases: vec!["json_get_int".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for JsonGetInt {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.aliases[0].as_str()
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        check_args(arg_types, self.name()).map(|()| DataType::Int64)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+        get_invoke::<Int64Array, i64>(
+            args,
+            jiter_json_get_int,
+            |c| Ok(Arc::new(c) as ArrayRef),
+            ScalarValue::Int64,
+        )
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result<i64, GetError> {
+    if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) {
+        match peek {
+            // numbers are represented by everything else in peek, hence doing it this way
+            Peek::Null
+            | Peek::True
+            | Peek::False
+            | Peek::Minus
+            | Peek::Infinity
+            | Peek::NaN
+            | Peek::String
+            | Peek::Array
+            | Peek::Object => get_err!(),
+            _ => match jiter.known_int(peek)? {
+                NumberInt::Int(i) => Ok(i),
+                NumberInt::BigInt(_) => get_err!(),
+            },
+        }
+    } else {
+        get_err!()
+    }
+}
diff --git a/src/json_get_json.rs b/src/json_get_json.rs
new file mode 100644
index 0000000..fad534e
--- /dev/null
+++ b/src/json_get_json.rs
@@ -0,0 +1,75 @@
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, StringArray};
+use arrow_schema::DataType;
+use datafusion_common::{Result as DataFusionResult, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+use crate::common_get::{check_args, get_err, get_invoke, jiter_json_find, GetError, JsonPath};
+use crate::common_macros::make_udf_function;
+
+make_udf_function!(
+    JsonGetJson,
+    json_get_json,
+    json_data path, // arg name
+    r#"Get any value from a JSON object by its "path", represented as a string"#
+);
+
+#[derive(Debug)]
+pub(super) struct JsonGetJson {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for JsonGetJson {
+    fn default() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+            aliases: vec!["json_get_json".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for JsonGetJson {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.aliases[0].as_str()
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        check_args(arg_types, self.name()).map(|()| DataType::Utf8)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+        get_invoke::<StringArray, String>(
+            args,
+            jiter_json_get_json,
+            |c| Ok(Arc::new(c) as ArrayRef),
+            ScalarValue::Utf8,
+        )
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result<String, GetError> {
+    if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
+        let start = jiter.current_index();
+        jiter.known_skip(peek)?;
+        let object_slice = jiter.slice_to_current(start);
+        let object_string = std::str::from_utf8(object_slice)?;
+        Ok(object_string.to_owned())
+    } else {
+        get_err!()
+    }
+}
diff --git a/src/json_get_str.rs b/src/json_get_str.rs
new file mode 100644
index 0000000..a1d3213
--- /dev/null
+++ b/src/json_get_str.rs
@@ -0,0 +1,75 @@
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, StringArray};
+use arrow_schema::DataType;
+use datafusion_common::{Result as DataFusionResult, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use jiter::Peek;
+
+use crate::common_get::{check_args, get_err, get_invoke, jiter_json_find, GetError, JsonPath};
+use crate::common_macros::make_udf_function;
+
+make_udf_function!(
+    JsonGetStr,
+    json_get_str,
+    json_data path, // arg name
+    r#"Get a string value from a JSON object by its "path""#
+);
+
+#[derive(Debug)]
+pub(super) struct JsonGetStr {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for JsonGetStr {
+    fn default() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+            aliases: vec!["json_get_str".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for JsonGetStr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.aliases[0].as_str()
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        check_args(arg_types, self.name()).map(|()| DataType::Utf8)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+        get_invoke::<StringArray, String>(
+            args,
+            jiter_json_get_str,
+            |c| Ok(Arc::new(c) as ArrayRef),
+            ScalarValue::Utf8,
+        )
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+fn jiter_json_get_str(json_data: Option<&str>, path: &[JsonPath]) -> Result<String, GetError> {
+    if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) {
+        match peek {
+            Peek::String => Ok(jiter.known_str()?.to_owned()),
+            _ => get_err!(),
+        }
+    } else {
+        get_err!()
+    }
+}
diff --git a/src/json_obj_contains.rs b/src/json_obj_contains.rs
index aa9d619..de09e68 100644
--- a/src/json_obj_contains.rs
+++ b/src/json_obj_contains.rs
@@ -1,19 +1,20 @@
-use crate::macros::make_udf_function;
+use std::any::Any;
+use std::sync::Arc;
+
 use arrow_schema::DataType;
 use arrow_schema::DataType::{LargeUtf8, Utf8};
 use datafusion_common::arrow::array::{as_string_array, ArrayRef, BooleanArray};
 use datafusion_common::{exec_err, plan_err, Result, ScalarValue};
 use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
-use jiter::Jiter;
-use std::any::Any;
-use std::sync::Arc;
+use jiter::{Jiter, JiterResult};
+
+use crate::common_macros::make_udf_function;
 
 make_udf_function!(
     JsonObjContains,
     json_obj_contains,
     json_data key, // arg name
-    "Does the string exist as a top-level key within the JSON value?", // doc
-    json_obj_contains_udf // internal function name
+    "Does the string exist as a top-level key within the JSON value?"
 );
 
 #[derive(Debug)]
@@ -22,8 +23,8 @@ pub(super) struct JsonObjContains {
     signature: Signature,
     aliases: Vec<String>,
 }
 
-impl JsonObjContains {
-    pub fn new() -> Self {
+impl Default for JsonObjContains {
+    fn default() -> Self {
         Self {
             signature: Signature::uniform(2, vec![Utf8, LargeUtf8], Volatility::Immutable),
             aliases: vec!["json_obj_contains".to_string(), "json_object_contains".to_string()],
@@ -48,13 +49,13 @@ impl ScalarUDFImpl for JsonObjContains {
         match arg_types[0] {
             Utf8 | LargeUtf8 => Ok(DataType::Boolean),
             _ => {
-                plan_err!("The json_obj_contains function can only accept Utf8 or LargeUtf8.")
+                plan_err!("The `json_obj_contains` function can only accept Utf8 or LargeUtf8.")
             }
         }
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        let json_haystack = match &args[0] {
+        let json_data = match &args[0] {
             ColumnarValue::Array(array) => as_string_array(array),
             ColumnarValue::Scalar(_) => {
                 return exec_err!("json_obj_contains first argument: unexpected argument type, expected string array")
@@ -66,9 +67,9 @@ impl ScalarUDFImpl for JsonObjContains {
             _ => return exec_err!("json_obj_contains second argument: unexpected argument type, expected string"),
         };
 
-        let array = json_haystack
+        let array = json_data
             .iter()
-            .map(|opt_json| opt_json.map(|json| jiter_json_contains(json.as_bytes(), &needle)))
+            .map(|opt_json| opt_json.map(|json| jiter_json_contains(json.as_bytes(), &needle).unwrap_or(false)))
             .collect::<BooleanArray>();
 
         Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
@@ -79,25 +80,21 @@ impl ScalarUDFImpl for JsonObjContains {
     }
 }
 
-fn jiter_json_contains(json_data: &[u8], expected_key: &str) -> bool {
+fn jiter_json_contains(json_data: &[u8], expected_key: &str) -> JiterResult<bool> {
     let mut jiter = Jiter::new(json_data, false);
-    let Ok(Some(first_key)) = jiter.next_object() else {
-        return false;
+    let Some(first_key) = jiter.next_object()? else {
+        return Ok(false);
     };
     if first_key == expected_key {
-        return true;
-    }
-    if jiter.next_skip().is_err() {
-        return false;
+        return Ok(true);
     }
+    jiter.next_skip()?;
     while let Ok(Some(key)) = jiter.next_key() {
         if key == expected_key {
-            return true;
-        }
-        if jiter.next_skip().is_err() {
-            return false;
+            return Ok(true);
         }
+        jiter.next_skip()?;
     }
-    false
+    Ok(false)
 }
diff --git a/src/lib.rs b/src/lib.rs
index 09be0f9..014dc13 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,16 +4,39 @@ use datafusion_expr::ScalarUDF;
 use log::debug;
 use std::sync::Arc;
 
+mod common_get;
+mod common_macros;
+mod common_union;
+mod json_get;
+mod json_get_bool;
+mod json_get_float;
+mod json_get_int;
+mod json_get_json;
+mod json_get_str;
 mod json_obj_contains;
-mod macros;
+mod rewrite;
 
 pub mod functions {
+    pub use crate::json_get::json_get;
+    pub use crate::json_get_bool::json_get_bool;
+    pub use crate::json_get_float::json_get_float;
+    pub use crate::json_get_int::json_get_int;
+    pub use crate::json_get_json::json_get_json;
+    pub use crate::json_get_str::json_get_str;
    pub use crate::json_obj_contains::json_obj_contains;
 }
 
 /// Register all JSON UDFs
 pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
-    let functions: Vec<Arc<ScalarUDF>> = vec![json_obj_contains::json_obj_contains_udf()];
+    let functions: Vec<Arc<ScalarUDF>> = vec![
+        json_get::json_get_udf(),
+        json_get_bool::json_get_bool_udf(),
+        json_get_float::json_get_float_udf(),
+        json_get_int::json_get_int_udf(),
+        json_get_json::json_get_json_udf(),
+        json_get_str::json_get_str_udf(),
+        json_obj_contains::json_obj_contains_udf(),
+    ];
     functions.into_iter().try_for_each(|udf| {
         let existing_udf = registry.register_udf(udf)?;
         if let Some(existing_udf) = existing_udf {
@@ -21,6 +44,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
         }
         Ok(()) as Result<()>
     })?;
+    registry.register_function_rewrite(Arc::new(rewrite::JsonFunctionRewriter))?;
 
     Ok(())
 }
diff --git a/src/macros.rs b/src/macros.rs
deleted file mode 100644
index 240c0da..0000000
--- a/src/macros.rs
+++ /dev/null
@@ -1,98 +0,0 @@
-#[allow(clippy::doc_markdown)]
-/// Currently copied verbatim, can hopefully be replaced or simplified
-/// https://github.com/apache/datafusion/blob/19356b26f515149f96f9b6296975a77ac7260149/datafusion/functions-array/src/macros.rs
-///
-/// Creates external API functions for an array UDF. Specifically, creates
-///
-/// 1. Single `ScalarUDF` instance
-///
-/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a
-/// function named `$NAME` which returns that function named $NAME.
-///
-/// This is used to ensure creating the list of `ScalarUDF` only happens once.
-///
-/// # 2. `expr_fn` style function
-///
-/// These are functions that create an `Expr` that invokes the UDF, used
-/// primarily to programmatically create expressions.
-///
-/// For example:
-/// ```text
-/// pub fn array_to_string(delimiter: Expr) -> Expr {
-///     ...
-/// }
-/// ```
-/// # Arguments
-/// * `UDF`: name of the [`ScalarUDFImpl`]
-/// * `EXPR_FN`: name of the `expr_fn` function to be created
-/// * `arg`: 0 or more named arguments for the function
-/// * `DOC`: documentation string for the function
-/// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF`
-/// * `GNAME`: name for the single static instance of the `ScalarUDF`
-///
-/// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl
-macro_rules! make_udf_function {
-    ($UDF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr , $SCALAR_UDF_FN:ident) => {
-        paste::paste! {
-            // "fluent expr_fn" style function
-            #[doc = $DOC]
-            #[must_use] pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr {
-                datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf(
-                    $SCALAR_UDF_FN(),
-                    vec![$($arg),*],
-                ))
-            }
-
-            /// Singleton instance of [`$UDF`], ensures the UDF is only created once
-            /// named STATIC_$(UDF). For example `STATIC_ArrayToString`
-            #[allow(non_upper_case_globals)]
-            static [< STATIC_ $UDF >]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::ScalarUDF>> =
-                std::sync::OnceLock::new();
-
-            /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`]
-            ///
-            /// [`ScalarUDF`]: datafusion_expr::ScalarUDF
-            pub fn $SCALAR_UDF_FN() -> std::sync::Arc<datafusion_expr::ScalarUDF> {
-                [< STATIC_ $UDF >]
-                    .get_or_init(|| {
-                        std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl(
-                            <$UDF>::new(),
-                        ))
-                    })
-                    .clone()
-            }
-        }
-    };
-    ($UDF:ty, $EXPR_FN:ident, $DOC:expr , $SCALAR_UDF_FN:ident) => {
-        paste::paste! {
-            // "fluent expr_fn" style function
-            #[doc = $DOC]
-            pub fn $EXPR_FN(arg: Vec<datafusion_expr::Expr>) -> datafusion_expr::Expr {
-                datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf(
-                    $SCALAR_UDF_FN(),
-                    arg,
-                ))
-            }
-
-            /// Singleton instance of [`$UDF`], ensures the UDF is only created once
-            /// named STATIC_$(UDF). For example `STATIC_ArrayToString`
-            #[allow(non_upper_case_globals)]
-            static [< STATIC_ $UDF >]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::ScalarUDF>> =
-                std::sync::OnceLock::new();
-            /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`]
-            ///
-            /// [`ScalarUDF`]: datafusion_expr::ScalarUDF
-            pub fn $SCALAR_UDF_FN() -> std::sync::Arc<datafusion_expr::ScalarUDF> {
-                [< STATIC_ $UDF >]
-                    .get_or_init(|| {
-                        std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl(
-                            <$UDF>::new(),
-                        ))
-                    })
-                    .clone()
-            }
-        }
-    };
-}
-
-pub(crate) use make_udf_function;
diff --git a/src/rewrite.rs b/src/rewrite.rs
new file mode 100644
index 0000000..3451118
--- /dev/null
+++ b/src/rewrite.rs
@@ -0,0 +1,46 @@
+use arrow::datatypes::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::DFSchema;
+use datafusion_common::Result;
+use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::expr_rewriter::FunctionRewrite;
+use datafusion_expr::{Expr, ScalarFunctionDefinition};
+
+pub(crate) struct JsonFunctionRewriter;
+
+impl FunctionRewrite for JsonFunctionRewriter {
+    fn name(&self) -> &str {
+        "JsonFunctionRewriter"
+    }
+
+    fn rewrite(&self, expr: Expr, _schema: &DFSchema, _config: &ConfigOptions) -> Result<Transformed<Expr>> {
+        if let Expr::Cast(cast) = &expr {
+            if let Expr::ScalarFunction(func) = &*cast.expr {
+                if let ScalarFunctionDefinition::UDF(udf) = &func.func_def {
+                    if udf.name() == "json_get" {
+                        if let Some(t) = switch_json_get(&cast.data_type, &func.args) {
+                            return Ok(t);
+                        }
+                    }
+                }
+            }
+        }
+        Ok(Transformed::no(expr))
+    }
+}
+
+fn switch_json_get(cast_data_type: &DataType, args: &[Expr]) -> Option<Transformed<Expr>> {
+    let udf = match cast_data_type {
+        DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
+        DataType::Float64 | DataType::Float32 => crate::json_get_float::json_get_float_udf(),
+        DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(),
+        DataType::Utf8 => crate::json_get_str::json_get_str_udf(),
+        _ => return None,
+    };
+    let f = ScalarFunction {
+        func_def: ScalarFunctionDefinition::UDF(udf),
+        args: args.to_vec(),
+    };
+    Some(Transformed::yes(Expr::ScalarFunction(f)))
+}
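Reviewer note: a hedged sketch of what `JsonFunctionRewriter` does in practice — when a cast is applied directly to `json_get`, the plan is rewritten to call the corresponding typed function instead, so no union value needs to be built and cast (queries are illustrative, not captured output):

```sql
-- written by the user:
select json_get(json_data, 'foo')::int from test;
-- effectively planned as:
select json_get_int(json_data, 'foo') from test;
```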
diff --git a/tests/test_json_get.rs b/tests/test_json_get.rs
new file mode 100644
index 0000000..9bf2cf5
--- /dev/null
+++ b/tests/test_json_get.rs
@@ -0,0 +1,259 @@
+use arrow_schema::DataType;
+use datafusion::assert_batches_eq;
+
+mod utils;
+use utils::{display_val, run_query};
+
+#[tokio::test]
+async fn test_json_get_union() {
+    let batches = run_query("select name, json_get(json_data, 'foo') from test")
+        .await
+        .unwrap();
+
+    let expected = [
+        "+------------------+--------------------------------------+",
+        "| name             | json_get(test.json_data,Utf8(\"foo\")) |",
+        "+------------------+--------------------------------------+",
+        "| object_foo       | {str=abc}                            |",
+        "| object_foo_array | {array=[1]}                          |",
+        "| object_foo_obj   | {object={}}                          |",
+        "| object_foo_null  | {null=true}                          |",
+        "| object_bar       | {null=}                              |",
+        "| list_foo         | {null=}                              |",
+        "| invalid_json     | {null=}                              |",
+        "+------------------+--------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_json_get_equals() {
+    let e = run_query(r"select name, json_get(json_data, 'foo')='abc' from test")
+        .await
+        .unwrap_err();
+
+    // see https://github.com/apache/datafusion/issues/10180
+    assert!(e
+        .to_string()
+        .starts_with("Error during planning: Cannot infer common argument type for comparison operation Union"));
+}
+
+#[tokio::test]
+async fn test_json_get_cast_equals() {
+    let batches = run_query(r"select name, json_get(json_data, 'foo')::string='abc' from test")
+        .await
+        .unwrap();
+
+    let expected = [
+        "+------------------+----------------------------------------------------+",
+        "| name             | json_get(test.json_data,Utf8(\"foo\")) = Utf8(\"abc\") |",
+        "+------------------+----------------------------------------------------+",
+        "| object_foo       | true                                               |",
+        "| object_foo_array |                                                    |",
+        "| object_foo_obj   |                                                    |",
+        "| object_foo_null  |                                                    |",
+        "| object_bar       |                                                    |",
+        "| list_foo         |                                                    |",
+        "| invalid_json     |                                                    |",
+        "+------------------+----------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_json_get_str() {
+    let batches = run_query("select name, json_get_str(json_data, 'foo') from test")
+        .await
+        .unwrap();
+
+    let expected = [
+        "+------------------+------------------------------------------+",
+        "| name             | json_get_str(test.json_data,Utf8(\"foo\")) |",
+        "+------------------+------------------------------------------+",
+        "| object_foo       | abc                                      |",
+        "| object_foo_array |                                          |",
+        "| object_foo_obj   |                                          |",
+        "| object_foo_null  |                                          |",
+        "| object_bar       |                                          |",
+        "| list_foo         |                                          |",
+        "| invalid_json     |                                          |",
+        "+------------------+------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_json_get_str_equals() {
+    let sql = "select name, json_get_str(json_data, 'foo')='abc' from test";
+    let batches = run_query(sql).await.unwrap();
+
+    let expected = [
+        "+------------------+--------------------------------------------------------+",
+        "| name             | json_get_str(test.json_data,Utf8(\"foo\")) = Utf8(\"abc\") |",
+        "+------------------+--------------------------------------------------------+",
+        "| object_foo       | true                                                   |",
+        "| object_foo_array |                                                        |",
+        "| object_foo_obj   |                                                        |",
+        "| object_foo_null  |                                                        |",
+        "| object_bar       |                                                        |",
+        "| list_foo         |                                                        |",
+        "| invalid_json     |                                                        |",
+        "+------------------+--------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_json_get_str_int() {
+    let sql = r#"select json_get_str('["a", "b", "c"]', 1) as v"#;
+    let batches = run_query(sql).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Utf8, "b".to_string()));
+
+    let sql = r#"select json_get_str('["a", "b", "c"]', 3) as v"#;
+    let batches = run_query(sql).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Utf8, String::new()));
+}
+
+#[tokio::test]
+async fn test_json_get_str_path() {
+    let sql = r#"select json_get_str('{"a": {"aa": "x", "ab": "y"}, "b": []}', 'a', 'aa') as v"#;
+    let batches = run_query(sql).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Utf8, "x".to_string()));
+}
+
+#[tokio::test]
+async fn test_json_get_str_null() {
+    let e = run_query(r"select json_get_str('{}', null)").await.unwrap_err();
+
+    assert_eq!(
+        e.to_string(),
+        "Error during planning: Unexpected argument type to `json_get_str` at position 2, expected string or int."
+    );
+}
+
+#[tokio::test]
+async fn test_json_get_no_path() {
+    let batches = run_query(r#"select json_get('"foo"')::string"#).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Utf8, "foo".to_string()));
+
+    let batches = run_query(r#"select json_get('123')::int"#).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Int64, "123".to_string()));
+
+    let batches = run_query(r#"select json_get('true')::int"#).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Int64, "".to_string()));
+}
+
+#[tokio::test]
+async fn test_json_get_int() {
+    let batches = run_query(r"select json_get_int('[1, 2, 3]', 1) as v").await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Int64, "2".to_string()));
+}
+
+#[tokio::test]
+async fn test_json_get_cast_int() {
+    let sql = r#"select json_get('{"foo": 42}', 'foo')::int as v"#;
+    let batches = run_query(sql).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string()));
+
+    // floats not allowed
+    let sql = r#"select json_get('{"foo": 4.2}', 'foo')::int as v"#;
+    let batches = run_query(sql).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Int64, String::new()));
+}
+
+#[tokio::test]
+async fn test_json_get_cast_int_path() {
+    let sql = r#"select json_get('{"foo": [null, {"x": false, "bar": 73}]}', 'foo', 1, 'bar')::int as v"#;
+    let batches = run_query(sql).await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Int64, "73".to_string()));
+}
+
+#[tokio::test]
+async fn test_json_get_int_lookup() {
+    let sql = "select str_key, json_data from other where json_get_int(json_data, str_key) is not null";
+    let batches = run_query(sql).await.unwrap();
+    let expected = [
+        "+---------+---------------+",
+        "| str_key | json_data     |",
+        "+---------+---------------+",
+        "| foo     |  {\"foo\": 42}  |",
+        "+---------+---------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+
+    // lookup by int
+    let sql = "select int_key, json_data from other where json_get_int(json_data, int_key) is not null";
+    let batches = run_query(sql).await.unwrap();
+    let expected = [
+        "+---------+-----------+",
+        "| int_key | json_data |",
+        "+---------+-----------+",
+        "| 0       |  [42]     |",
+        "+---------+-----------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_json_get_float() {
+    let batches = run_query("select json_get_float('[1.5]', 0) as v").await.unwrap();
+    assert_eq!(display_val(batches).await, (DataType::Float64, "1.5".to_string()));
+
+    // coerce int to float
run_query("select json_get_float('[1]', 0) as v").await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Float64, "1.0".to_string())); +} + +#[tokio::test] +async fn test_json_get_cast_float() { + let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::float as v"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); +} + +#[tokio::test] +async fn test_json_get_bool() { + let batches = run_query("select json_get_bool('[true]', 0) as v").await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); + + let batches = run_query(r#"select json_get_bool('{"": false}', '') as v"#) + .await + .unwrap(); + assert_eq!(display_val(batches).await, (DataType::Boolean, "false".to_string())); +} + +#[tokio::test] +async fn test_json_get_cast_bool() { + let sql = r#"select json_get('{"foo": true}', 'foo')::bool as v"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); +} + +#[tokio::test] +async fn test_json_get_json() { + let batches = run_query("select name, json_get_json(json_data, 'foo') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-------------------------------------------+", + "| name | json_get_json(test.json_data,Utf8(\"foo\")) |", + "+------------------+-------------------------------------------+", + "| object_foo | \"abc\" |", + "| object_foo_array | [1] |", + "| object_foo_obj | {} |", + "| object_foo_null | null |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_get_json_float() { + let sql = r#"select json_get_json('{"x": 4.2e-1}', 'x')"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Utf8, "4.2e-1".to_string())); +} diff --git a/tests/test_json_obj_contains.rs b/tests/test_json_obj_contains.rs index 20c704c..35062c8 100644 --- a/tests/test_json_obj_contains.rs +++ b/tests/test_json_obj_contains.rs @@ -1,69 +1,26 @@ -use arrow::datatypes::{DataType, Field, Schema}; -use arrow::{array::StringArray, record_batch::RecordBatch}; -use std::sync::Arc; - use datafusion::assert_batches_eq; -use datafusion::error::Result; -use datafusion::execution::context::SessionContext; -use datafusion_functions_json::register_all; - -async fn create_test_table() -> Result { - let schema = Arc::new(Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("json_data", DataType::Utf8, false), - ])); - - let data = [ - ("object_foo", r#" {"foo": 123} "#), - ("object_bar", r#" {"bar": true} "#), - ("list_foo", r#" ["foo"] "#), - ("invalid_json", "is not json"), - ]; - - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(StringArray::from( - data.iter().map(|(name, _)| *name).collect::>(), - )), - Arc::new(StringArray::from( - data.iter().map(|(_, json)| *json).collect::>(), - )), - ], - )?; - let mut ctx = SessionContext::new(); - register_all(&mut ctx)?; - ctx.register_batch("test", batch)?; - Ok(ctx) -} - -/// Executes an expression on the test dataframe as a select. -/// Compares formatted output of a record batch with an expected -/// vector of strings, using the `assert_batch_eq`! macro -macro_rules! 
-    ($sql:expr, $expected: expr) => {
-        let ctx = create_test_table().await?;
-        let df = ctx.sql($sql).await?;
-        let batches = df.collect().await?;
-
-        assert_batches_eq!($expected, &batches);
-    };
-}
+mod utils;
+use utils::run_query;
 
 #[tokio::test]
-async fn test_json_obj_contains() -> Result<()> {
+async fn test_json_obj_contains() {
     let expected = [
-        "+--------------+-----------------------------------------------+",
-        "| name         | json_obj_contains(test.json_data,Utf8(\"foo\")) |",
-        "+--------------+-----------------------------------------------+",
-        "| object_foo   | true                                          |",
-        "| object_bar   | false                                         |",
-        "| list_foo     | false                                         |",
-        "| invalid_json | false                                         |",
-        "+--------------+-----------------------------------------------+",
+        "+------------------+-----------------------------------------------+",
+        "| name             | json_obj_contains(test.json_data,Utf8(\"foo\")) |",
+        "+------------------+-----------------------------------------------+",
+        "| object_foo       | true                                          |",
+        "| object_foo_array | true                                          |",
+        "| object_foo_obj   | true                                          |",
+        "| object_foo_null  | true                                          |",
+        "| object_bar       | false                                         |",
+        "| list_foo         | false                                         |",
+        "| invalid_json     | false                                         |",
+        "+------------------+-----------------------------------------------+",
     ];
-    query!("select name, json_obj_contains(json_data, 'foo') from test", expected);
-    Ok(())
+    let batches = run_query("select name, json_obj_contains(json_data, 'foo') from test")
+        .await
+        .unwrap();
+    assert_batches_eq!(expected, &batches);
 }
diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs
new file mode 100644
index 0000000..c84cdb8
--- /dev/null
+++ b/tests/utils/mod.rs
@@ -0,0 +1,87 @@
+#![allow(dead_code)]
+use arrow::array::Int64Array;
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow::util::display::{ArrayFormatter, FormatOptions};
+use arrow::{array::StringArray, record_batch::RecordBatch};
+use std::sync::Arc;
+
+use datafusion::error::Result;
+use datafusion::execution::context::SessionContext;
+use datafusion_functions_json::register_all;
+
+async fn create_test_table() -> Result<SessionContext> {
+    let mut ctx = SessionContext::new();
+    register_all(&mut ctx)?;
+
+    let test_data = [
+        ("object_foo", r#" {"foo": "abc"} "#),
+        ("object_foo_array", r#" {"foo": [1]} "#),
+        ("object_foo_obj", r#" {"foo": {}} "#),
+        ("object_foo_null", r#" {"foo": null} "#),
+        ("object_bar", r#" {"bar": true} "#),
+        ("list_foo", r#" ["foo"] "#),
+        ("invalid_json", "is not json"),
+    ];
+    let test_batch = RecordBatch::try_new(
+        Arc::new(Schema::new(vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new("json_data", DataType::Utf8, false),
+        ])),
+        vec![
+            Arc::new(StringArray::from(
+                test_data.iter().map(|(name, _)| *name).collect::<Vec<&str>>(),
+            )),
+            Arc::new(StringArray::from(
+                test_data.iter().map(|(_, json)| *json).collect::<Vec<&str>>(),
+            )),
+        ],
+    )?;
+    ctx.register_batch("test", test_batch)?;
+
+    let other_data = [
+        (r#" {"foo": 42} "#, "foo", 0),
+        (r#" {"foo": 42} "#, "bar", 1),
+        (r" [42] ", "foo", 0),
+        (r" [42] ", "bar", 1),
+    ];
+    let other_batch = RecordBatch::try_new(
+        Arc::new(Schema::new(vec![
+            Field::new("json_data", DataType::Utf8, false),
+            Field::new("str_key", DataType::Utf8, false),
+            Field::new("int_key", DataType::Int64, false),
+        ])),
+        vec![
+            Arc::new(StringArray::from(
+                other_data.iter().map(|(json, _, _)| *json).collect::<Vec<&str>>(),
+            )),
+            Arc::new(StringArray::from(
+                other_data.iter().map(|(_, str_key, _)| *str_key).collect::<Vec<&str>>(),
+            )),
+            Arc::new(Int64Array::from(
+                other_data.iter().map(|(_, _, int_key)| *int_key).collect::<Vec<i64>>(),
+            )),
+        ],
+    )?;
ctx.register_batch("other", other_batch)?; + + Ok(ctx) +} + +pub async fn run_query(sql: &str) -> Result> { + let ctx = create_test_table().await?; + let df = ctx.sql(sql).await?; + df.collect().await +} + +pub async fn display_val(batch: Vec) -> (DataType, String) { + assert_eq!(batch.len(), 1); + let batch = batch.first().unwrap(); + assert_eq!(batch.num_rows(), 1); + let schema = batch.schema(); + let schema_col = schema.field(0); + let c = batch.column(0); + let options = FormatOptions::default().with_display_error(true); + let f = ArrayFormatter::try_new(c.as_ref(), &options).unwrap(); + let repr = f.value(0).try_to_string().unwrap(); + (schema_col.data_type().clone(), repr) +}