Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
Expand All @@ -27,7 +28,7 @@ use arrow::array::{
};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::record_batch::RecordBatch;
use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility};
use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion::{
Expand All @@ -37,7 +38,7 @@ use datafusion::{
use datafusion_catalog::CatalogProvider;
use datafusion_catalog::{memory::MemoryCatalogProvider, memory::MemorySchemaProvider};
use datafusion_common::cast::as_float64_array;
use datafusion_common::DataFusionError;
use datafusion_common::{plan_err, Result, DataFusionError};

use async_trait::async_trait;
use datafusion::catalog::Session;
Expand Down Expand Up @@ -113,6 +114,10 @@ impl TestContext {
info!("Registering metadata table tables");
register_metadata_tables(test_ctx.session_ctx()).await;
}
"case.slt" => {
info!("Registering case conversaion");
register_to_large_list(test_ctx.session_ctx());
}
_ => {
info!("Using default SessionContext");
}
Expand Down Expand Up @@ -214,7 +219,7 @@ pub async fn register_temp_table(ctx: &SessionContext) {

#[async_trait]
impl TableProvider for TestTable {
fn as_any(&self) -> &dyn std::any::Any {
fn as_any(&self) -> &dyn Any {
self
}

Expand Down Expand Up @@ -402,3 +407,58 @@ fn create_example_udf() -> ScalarUDF {
adder,
)
}

fn register_to_large_list(ctx: &SessionContext) {
ctx.register_udf(ScalarUDF::from(ToLargeList::new()))
}

/// `to_large_list()` converts its argument `ListArray` to a `LargeListArray`
#[derive(Debug)]
struct ToLargeList {
signature: Signature
}

impl ToLargeList {
fn new() -> Self {
Self {
signature: Signature::array(Volatility::Immutable)
}
}
}
impl ScalarUDFImpl for ToLargeList {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"to_large_list"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 1 {
return plan_err!("to_large_list() takes exactly one argument");
}
let DataType::List(field) = &arg_types[0] else {
return plan_err!("to_large_list() takes a list as its argument");
};

Ok(DataType::LargeList(Arc::clone(field)))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
match &args.args[0] {
ColumnarValue::Array(array) => {
let cast_array = arrow::compute::cast(array, args.return_type)?;
Ok(ColumnarValue::Array(cast_array))
}
ColumnarValue::Scalar(scalar) => {
let cast_scalar = scalar.cast_to(args.return_type)?;
Ok(ColumnarValue::Scalar(cast_scalar))
}
}
}
}
81 changes: 81 additions & 0 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,84 @@ NULL NULL false

statement ok
drop table foo

####
#### Case with structs with subfields that need to be coerced
#####

statement ok
create table t as values
(
true, -- column1 boolean (so the case isn't constant folded)
[{ 'foo': 'bar' }], -- column2 has List of Struct w/ Utf8
[{ 'foo': arrow_cast('bar', 'Utf8View') }] -- column3 has List of Struct w/ Utf8View
)

query B?
select column1, column2 from t;
----
true [{foo: bar}]

query TT
select arrow_typeof(column1), arrow_typeof(column2) from t;
----
Boolean List(Field { name: "item", data_type: Struct([Field { name: "foo", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })

# Force coercion of column2 to Utf8View
query ??
select
case when column1 then column2 else column3 end,
case when not column1 then column2 else column3 end
from t;
----
[{c0: bar}] [{c0: bar}]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is clearly wrong -- it should have 'foo' as the field name


query T
select
arrow_typeof(case when column1 then column2 else column3 end)
from t;
----
List(Field { name: "item", data_type: Struct([Field { name: "c0", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })


# Force coercion of column2 to a LargeList
# this requires that the List is coerced as well as the Struct fields within the list
query ?
select
case when column1 then to_large_list(column2) else column3 end
from t;
----
[{c0: bar}]

query T
select
arrow_typeof(case when column1 then to_large_list(column2) else column3 end)
from t;
----
LargeList(Field { name: "item", data_type: Struct([Field { name: "c0", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })

# Force coercion of column3 to a LargeList
# this requires that the List is coerced as well as the Struct fields within the list
query ?
select
case when column1 then column2 else to_large_list(column3) end
from t;
----
[{c0: bar}]

# Force coercion of column2 to a LargeList with inner coersion
# this requires that the List is coerced as well as the Struct fields within the list
query ?
select
case
when column1 IS TRUE then column2
when column1 IS FALSE then to_large_list(column3)
else NULL
end
from t;
----
[{c0: bar}]


statement ok
drop table t;
Loading