Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
mod arrow_cast;
mod arrowtypeof;
mod getfield;
mod named_struct;
mod nullif;
mod nvl;
mod nvl2;
Expand All @@ -32,6 +33,7 @@ make_udf_function!(nvl::NVLFunc, NVL, nvl);
make_udf_function!(nvl2::NVL2Func, NVL2, nvl2);
make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof);
make_udf_function!(r#struct::StructFunc, STRUCT, r#struct);
make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct);
make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);

// Export the functions out of this package, both as expr_fn as well as a list of functions
Expand All @@ -42,5 +44,6 @@ export_functions!(
(nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."),
(arrow_typeof, arg_1, "Returns the Arrow type of the input expression."),
(r#struct, args, "Returns a struct with the given arguments"),
(named_struct, args, "Returns a struct with the given names and arguments pairs"),
(get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct")
);
196 changes: 196 additions & 0 deletions datafusion/functions/src/core/named_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::StructArray;
use arrow::datatypes::{DataType, Field, Fields};
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

/// Puts the given name/value argument pairs into a struct array.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than adding a new function, I think we could instead change the existing struct UDF to support explicit named fields.

Perhaps we can do that as a follow-on PR?

/// Builds a struct array from alternating `name, value` arguments.
///
/// `args` is read in pairs: the first element of each pair supplies the field
/// name and the second supplies the field's values.
///
/// # Errors
///
/// Returns an execution error when `args` is empty, when it holds an odd
/// number of elements, or when the first element of a pair is not a non-null
/// `Utf8` scalar. Errors from `ColumnarValue::values_to_arrays` are propagated.
fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
// do not accept 0 arguments.
if args.is_empty() {
return exec_err!("named_struct requires at least one pair of arguments, got 0 instead");
}

// arguments are consumed as (name, value) pairs, so the count must be even.
if args.len() % 2 != 0 {
return exec_err!("named_struct requires an even number of arguments, got {} instead", args.len());
}

// Split the pairs into a list of field names and a list of value columns,
// stopping at the first pair whose name is not a non-null Utf8 scalar.
// `i * 2` in the error message is the 0-based position of that name in `args`.
let (names, values): (Vec<_>, Vec<_>) = args
.chunks_exact(2)
.enumerate()
.map(|(i, chunk)| {

let name_column = &chunk[0];

let name = match name_column {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar,
_ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

};

Ok((name, chunk[1].clone()))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
.unzip();

// Turn every value column into an array.
// NOTE(review): presumably scalars are expanded to a common length here — confirm
// against `ColumnarValue::values_to_arrays`.
let arrays = ColumnarValue::values_to_arrays(&values)?;

// Pair each name with its array; the field takes the array's data type and
// is always created with the nullable flag set to `true`.
let fields = names.into_iter()
.zip(arrays)
.map(|(name, value)| {
(
Arc::new(Field::new(
name,
value.data_type().clone(),
true,
)),
value,
)
})
.collect::<Vec<_>>();

// The result is always returned as an array, never as a scalar.
Ok(ColumnarValue::Array(Arc::new(StructArray::from(fields))))
}

/// Scalar UDF implementation behind the `named_struct` function.
#[derive(Debug)]
pub(super) struct NamedStructFunc {
    /// Argument signature reported for this function.
    signature: Signature,
}

impl NamedStructFunc {
    /// Creates the function with a variadic, any-type, immutable signature.
    pub fn new() -> Self {
        let signature = Signature::variadic_any(Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for NamedStructFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "named_struct"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always fails: the struct's field names come from the argument
    /// expressions, so only `return_type_from_exprs` can compute the type.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("named_struct: return_type called instead of return_type_from_exprs")
    }

    /// Derives the returned struct type from the `name, value` expression pairs.
    ///
    /// Each name must be a non-null `Utf8` literal; the matching field takes
    /// the type of the value expression resolved against `schema` and is
    /// created with the nullable flag set to `true`.
    ///
    /// # Errors
    ///
    /// Fails when `args` is empty, has an odd length, contains a name that is
    /// not a string literal, or when a value's type cannot be resolved.
    fn return_type_from_exprs(
        &self,
        args: &[datafusion_expr::Expr],
        schema: &dyn datafusion_common::ExprSchema,
        _arg_types: &[DataType],
    ) -> Result<DataType> {
        let arg_count = args.len();

        // At least one pair is required.
        if arg_count == 0 {
            return exec_err!("named_struct requires at least one pair of arguments, got 0 instead");
        }

        // Names and values alternate, so an odd count means a dangling name.
        if arg_count % 2 == 1 {
            return exec_err!("named_struct requires an even number of arguments, got {} instead", args.len());
        }

        let mut return_fields: Vec<Field> = Vec::with_capacity(arg_count / 2);
        for (i, pair) in args.chunks_exact(2).enumerate() {
            let name = &pair[0];
            let field_name = match name {
                Expr::Literal(ScalarValue::Utf8(Some(literal))) => literal,
                // `i * 2` is the 0-based position of the offending name in `args`.
                _ => return exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2),
            };
            let field_type = pair[1].get_type(schema)?;
            return_fields.push(Field::new(field_name, field_type, true));
        }

        Ok(DataType::Struct(Fields::from(return_fields)))
    }

    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        named_struct_expr(args)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use datafusion_common::cast::as_struct_array;
    use datafusion_common::ScalarValue;

    /// Checks that three scalar name/value pairs become a one-row struct whose
    /// columns can be looked up by name.
    #[test]
    fn test_named_struct() {
        // named_struct("first", 1, "second", 2, "third", 3) = {"first": 1, "second": 2, "third": 3}
        let expected: [(&str, i64); 3] = [("first", 1), ("second", 2), ("third", 3)];

        // Flatten the table into the alternating name, value argument list.
        let args: Vec<ColumnarValue> = expected
            .iter()
            .flat_map(|(name, value)| {
                [
                    ColumnarValue::Scalar(ScalarValue::Utf8(Some((*name).to_owned()))),
                    ColumnarValue::Scalar(ScalarValue::Int64(Some(*value))),
                ]
            })
            .collect();

        let struc = named_struct_expr(&args)
            .expect("failed to initialize function struct")
            .into_array(1)
            .expect("Failed to convert to array");
        let result =
            as_struct_array(&struc).expect("failed to initialize function struct");

        for (field, value) in expected {
            let column = result.column_by_name(field).unwrap().clone();
            let actual = column.as_any().downcast_ref::<Int64Array>().unwrap();
            assert_eq!(&Int64Array::from(vec![value]), actual);
        }
    }
}
45 changes: 36 additions & 9 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ use datafusion_expr::expr::InList;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Literal, Operator,
TryCast,
};

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
Expand Down Expand Up @@ -604,18 +605,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
let args = values
.into_iter()
.map(|value| {
self.sql_expr_to_logical_expr(value, input_schema, planner_context)
.enumerate()
.map(|(i, value)| {
let args = if let SQLExpr::Named { expr, name } = value {
[
name.value.lit(),
self.sql_expr_to_logical_expr(
*expr,
input_schema,
planner_context,
)?,
]
} else {
[
format!("c{i}").lit(),
self.sql_expr_to_logical_expr(
value,
input_schema,
planner_context,
)?,
]
};

Ok(args)
})
.collect::<Result<Vec<_>>>()?;
let struct_func = self
.collect::<Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect();

let named_struct_func = self
.context_provider
.get_function_meta("struct")
.get_function_meta("named_struct")
.ok_or_else(|| {
internal_datafusion_err!("Unable to find expected 'struct' function")
})?;
internal_datafusion_err!("Unable to find expected 'named_struct' function")
})?;

Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
struct_func,
named_struct_func,
args,
)))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ query TT
explain select struct(1, 2.3, 'abc');
----
logical_plan
Projection: Struct({c0:1,c1:2.3,c2:abc}) AS struct(Int64(1),Float64(2.3),Utf8("abc"))
Projection: Struct({c0:1,c1:2.3,c2:abc}) AS named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc"))
--EmptyRelation
physical_plan
ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))]
ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc"))]
--PlaceholderRowExec
Loading