Skip to content
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"

[dependencies]
arrow = "51.0.0"
arrow-schema = "51.0.0"
datafusion-common = "37.0.0"
datafusion-expr = "37.0.0"
Expand All @@ -13,7 +14,6 @@ log = "0.4.21"
datafusion-execution = "37.0.0"

[dev-dependencies]
arrow = "51.0.0"
datafusion = "37.0.0"
tokio = { version = "1.37.0", features = ["full"] }

Expand All @@ -25,3 +25,4 @@ print_stdout = "warn"
# certain lints which we don't want to enforce (for now)
pedantic = { level = "warn", priority = -1 }
missing_errors_doc = "allow"
cast_possible_truncation = "allow"
203 changes: 203 additions & 0 deletions src/json_get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{Array, UnionArray};
use arrow_schema::DataType;
use datafusion_common::arrow::array::{as_string_array, ArrayRef};
use datafusion_common::{exec_err, plan_err, Result as DfResult, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::{Jiter, JiterError, NumberAny, NumberInt, Peek};

use crate::macros::make_udf_function;
use crate::union::{JsonUnion, JsonUnionField};

make_udf_function!(
JsonGet,
json_get,
json_data key, // arg name
r#"Get a value from a JSON object by it's "path""#
);

#[derive(Debug)]
pub(super) struct JsonGet {
signature: Signature,
aliases: Vec<String>,
}

impl JsonGet {
pub fn new() -> Self {
Self {
signature: Signature::variadic(vec![DataType::Utf8, DataType::UInt64], Volatility::Immutable),
aliases: vec!["json_get".to_string()],
}
}
}

impl ScalarUDFImpl for JsonGet {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"json_get"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> DfResult<DataType> {
if arg_types.len() < 2 {
return plan_err!("The `json_get` function requires two or more arguments.");
}
match arg_types[0] {
DataType::Utf8 => Ok(JsonUnion::data_type()),
_ => {
plan_err!("The `json_get` function can only accept Utf8 or LargeUtf8.")
}
}
}

fn invoke(&self, args: &[ColumnarValue]) -> DfResult<ColumnarValue> {
let json_data = match &args[0] {
ColumnarValue::Array(array) => as_string_array(array),
ColumnarValue::Scalar(_) => {
return exec_err!("json_get first argument: unexpected argument type, expected string array")
}
};

let path = args[1..]
.iter()
.enumerate()
.map(|(index, arg)| match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Ok(JsonPath::Key(s)),
ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => Ok(JsonPath::Index(*i as usize)),
_ => exec_err!(
"json_get: unexpected argument type at {}, expected string or int",
index + 2
),
})
.collect::<DfResult<Vec<JsonPath>>>()?;

let mut union = JsonUnion::new(json_data.len());
for opt_json in json_data {
if let Some(union_field) = jiter_json_get(opt_json, &path) {
union.push(union_field);
} else {
union.push_none();
}
}
let array: UnionArray = union.try_into()?;

Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

enum JsonPath<'s> {
Key(&'s str),
Index(usize),
}

struct GetError;

impl From<JiterError> for GetError {
fn from(_: JiterError) -> Self {
GetError
}
}

fn jiter_json_get(opt_json: Option<&str>, path: &[JsonPath]) -> Option<JsonUnionField> {
if let Some(json_str) = opt_json {
let mut jiter = Jiter::new(json_str.as_bytes(), false);
if let Ok(peek) = jiter.peek() {
return _jiter_json_get(&mut jiter, peek, path).ok();
}
}
None
}

fn _jiter_json_get(jiter: &mut Jiter, peek: Peek, path: &[JsonPath]) -> Result<JsonUnionField, GetError> {
let (first, rest) = path.split_first().unwrap();
let next_peek = match peek {
Peek::Array => match first {
JsonPath::Index(index) => jiter_array_get(jiter, *index),
JsonPath::Key(_) => Err(GetError),
},
Peek::Object => match first {
JsonPath::Key(key) => jiter_object_get(jiter, key),
JsonPath::Index(_) => Err(GetError),
},
_ => Err(GetError),
}?;

if rest.is_empty() {
match next_peek {
Peek::Null => {
jiter.known_null()?;
Ok(JsonUnionField::JsonNull)
}
Peek::True | Peek::False => {
let value = jiter.known_bool(next_peek)?;
Ok(JsonUnionField::Bool(value))
}
Peek::String => {
let value = jiter.known_str()?;
Ok(JsonUnionField::String(value.to_owned()))
}
Peek::Array => {
let start = jiter.current_index();
jiter.known_skip(next_peek)?;
let array_slice = jiter.slice_to_current(start);
let array_string = std::str::from_utf8(array_slice).map_err(|_| GetError)?;
Ok(JsonUnionField::Array(array_string.to_owned()))
}
Peek::Object => {
let start = jiter.current_index();
jiter.known_skip(next_peek)?;
let object_slice = jiter.slice_to_current(start);
let object_string = std::str::from_utf8(object_slice).map_err(|_| GetError)?;
Ok(JsonUnionField::Object(object_string.to_owned()))
}
_ => match jiter.known_number(next_peek)? {
NumberAny::Int(NumberInt::Int(value)) => Ok(JsonUnionField::Int(value)),
NumberAny::Int(NumberInt::BigInt(_)) => todo!("BigInt not supported yet"),
NumberAny::Float(value) => Ok(JsonUnionField::Float(value)),
},
}
} else {
// we still have more of the path to traverse, recurse
_jiter_json_get(jiter, next_peek, rest)
}
}

fn jiter_array_get(jiter: &mut Jiter, find_key: usize) -> Result<Peek, GetError> {
let mut peek_opt = jiter.known_array()?;

let mut index: usize = 0;
while let Some(peek) = peek_opt {
if index == find_key {
return Ok(peek);
}
index += 1;
peek_opt = jiter.next_array()?;
}
Err(GetError)
}

fn jiter_object_get(jiter: &mut Jiter, find_key: &str) -> Result<Peek, GetError> {
let mut opt_key = jiter.known_object()?;

while let Some(key) = opt_key {
if key == find_key {
let value_peek = jiter.peek()?;
return Ok(value_peek);
}
jiter.next_skip()?;
opt_key = jiter.next_key()?;
}
Err(GetError)
}
33 changes: 14 additions & 19 deletions src/json_obj_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ use arrow_schema::DataType::{LargeUtf8, Utf8};
use datafusion_common::arrow::array::{as_string_array, ArrayRef, BooleanArray};
use datafusion_common::{exec_err, plan_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::Jiter;
use jiter::{Jiter, JiterResult};
use std::any::Any;
use std::sync::Arc;

make_udf_function!(
JsonObjContains,
json_obj_contains,
json_data key, // arg name
"Does the string exist as a top-level key within the JSON value?", // doc
json_obj_contains_udf // internal function name
"Does the string exist as a top-level key within the JSON value?"
);

#[derive(Debug)]
Expand Down Expand Up @@ -48,13 +47,13 @@ impl ScalarUDFImpl for JsonObjContains {
match arg_types[0] {
Utf8 | LargeUtf8 => Ok(DataType::Boolean),
_ => {
plan_err!("The json_obj_contains function can only accept Utf8 or LargeUtf8.")
plan_err!("The `json_obj_contains` function can only accept Utf8 or LargeUtf8.")
}
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let json_haystack = match &args[0] {
let json_data = match &args[0] {
ColumnarValue::Array(array) => as_string_array(array),
ColumnarValue::Scalar(_) => {
return exec_err!("json_obj_contains first argument: unexpected argument type, expected string array")
Expand All @@ -66,9 +65,9 @@ impl ScalarUDFImpl for JsonObjContains {
_ => return exec_err!("json_obj_contains second argument: unexpected argument type, expected string"),
};

let array = json_haystack
let array = json_data
.iter()
.map(|opt_json| opt_json.map(|json| jiter_json_contains(json.as_bytes(), &needle)))
.map(|opt_json| opt_json.map(|json| jiter_json_contains(json.as_bytes(), &needle).unwrap_or(false)))
.collect::<BooleanArray>();

Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
Expand All @@ -79,25 +78,21 @@ impl ScalarUDFImpl for JsonObjContains {
}
}

fn jiter_json_contains(json_data: &[u8], expected_key: &str) -> bool {
fn jiter_json_contains(json_data: &[u8], expected_key: &str) -> JiterResult<bool> {
let mut jiter = Jiter::new(json_data, false);
let Ok(Some(first_key)) = jiter.next_object() else {
return false;
let Some(first_key) = jiter.next_object()? else {
return Ok(false);
};

if first_key == expected_key {
return true;
}
if jiter.next_skip().is_err() {
return false;
return Ok(true);
}
jiter.next_skip()?;
while let Ok(Some(key)) = jiter.next_key() {
if key == expected_key {
return true;
}
if jiter.next_skip().is_err() {
return false;
return Ok(true);
}
jiter.next_skip()?;
}
false
Ok(false)
}
7 changes: 6 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,28 @@ use datafusion_expr::ScalarUDF;
use log::debug;
use std::sync::Arc;

mod json_get;
mod json_obj_contains;
mod macros;
mod rewrite;
mod union;

pub mod functions {
pub use crate::json_get::json_get;
pub use crate::json_obj_contains::json_obj_contains;
}

/// Register all JSON UDFs
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
let functions: Vec<Arc<ScalarUDF>> = vec![json_obj_contains::json_obj_contains_udf()];
let functions: Vec<Arc<ScalarUDF>> = vec![json_obj_contains::json_obj_contains_udf(), json_get::json_get_udf()];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
if let Some(existing_udf) = existing_udf {
debug!("Overwrite existing UDF: {}", existing_udf.name());
}
Ok(()) as Result<()>
})?;
registry.register_function_rewrite(Arc::new(rewrite::JsonFunctionRewriter))?;

Ok(())
}
Loading