diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index f7c9346a8983..154686082458 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; @@ -27,7 +28,7 @@ use arrow::array::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; -use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; +use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion::{ @@ -37,7 +38,7 @@ use datafusion::{ use datafusion_catalog::CatalogProvider; use datafusion_catalog::{memory::MemoryCatalogProvider, memory::MemorySchemaProvider}; use datafusion_common::cast::as_float64_array; -use datafusion_common::DataFusionError; +use datafusion_common::{plan_err, Result, DataFusionError}; use async_trait::async_trait; use datafusion::catalog::Session; @@ -113,6 +114,10 @@ impl TestContext { info!("Registering metadata table tables"); register_metadata_tables(test_ctx.session_ctx()).await; } + "case.slt" => { + info!("Registering case conversaion"); + register_to_large_list(test_ctx.session_ctx()); + } _ => { info!("Using default SessionContext"); } @@ -214,7 +219,7 @@ pub async fn register_temp_table(ctx: &SessionContext) { #[async_trait] impl TableProvider for TestTable { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -402,3 +407,58 @@ fn create_example_udf() -> ScalarUDF { adder, ) } + +fn register_to_large_list(ctx: &SessionContext) { + ctx.register_udf(ScalarUDF::from(ToLargeList::new())) +} + +/// `to_large_list()` converts its argument `ListArray` to a `LargeListArray` +#[derive(Debug)] +struct ToLargeList { + signature: Signature +} + +impl ToLargeList { + fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable) + } + } +} +impl ScalarUDFImpl for ToLargeList { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_large_list" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return plan_err!("to_large_list() takes exactly one argument"); + } + let DataType::List(field) = &arg_types[0] else { + return plan_err!("to_large_list() takes a list as its argument"); + }; + + Ok(DataType::LargeList(Arc::clone(field))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match &args.args[0] { + ColumnarValue::Array(array) => { + let cast_array = arrow::compute::cast(array, args.return_type)?; + Ok(ColumnarValue::Array(cast_array)) + } + ColumnarValue::Scalar(scalar) => { + let cast_scalar = scalar.cast_to(args.return_type)?; + Ok(ColumnarValue::Scalar(cast_scalar)) + } + } + } +} diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index a339c2aa037e..17363f55a49c 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -308,3 +308,84 @@ NULL NULL false statement ok drop table foo + +#### +#### Case with structs with subfields that need to be coerced +##### + +statement ok +create table t as values +( + true, -- column1 boolean (so the case isn't constant folded) + [{ 'foo': 'bar' }], -- column2 has List of Struct w/ Utf8 + [{ 'foo': arrow_cast('bar', 'Utf8View') }] -- column3 has List of Struct w/ Utf8View +) + +query B? +select column1, column2 from t; +---- +true [{foo: bar}] + +query TT +select arrow_typeof(column1), arrow_typeof(column2) from t; +---- +Boolean List(Field { name: "item", data_type: Struct([Field { name: "foo", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +# Force coercion of column2 to Utf8View +query ?? +select + case when column1 then column2 else column3 end, + case when not column1 then column2 else column3 end +from t; +---- +[{c0: bar}] [{c0: bar}] + +query T +select + arrow_typeof(case when column1 then column2 else column3 end) +from t; +---- +List(Field { name: "item", data_type: Struct([Field { name: "c0", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + +# Force coercion of column2 to a LargeList +# this requires that the List is coerced as well as the Struct fields within the list +query ? +select + case when column1 then to_large_list(column2) else column3 end +from t; +---- +[{c0: bar}] + +query T +select + arrow_typeof(case when column1 then to_large_list(column2) else column3 end) +from t; +---- +LargeList(Field { name: "item", data_type: Struct([Field { name: "c0", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +# Force coercion of column3 to a LargeList +# this requires that the List is coerced as well as the Struct fields within the list +query ? +select + case when column1 then column2 else to_large_list(column3) end +from t; +---- +[{c0: bar}] + +# Force coercion of column2 to a LargeList with inner coersion +# this requires that the List is coerced as well as the Struct fields within the list +query ? +select + case + when column1 IS TRUE then column2 + when column1 IS FALSE then to_large_list(column3) + else NULL + end +from t; +---- +[{c0: bar}] + + +statement ok +drop table t;