Skip to content

Commit 0c02cad

Browse files
authored
Support IS NULL and IS NOT NULL on Unions (#11321)
* Demonstrate unions can't be null * add scalar test cases * support "IS NULL" and "IS NOT NULL" on unions * formatting * fix comments from @alamb * fix docstring
1 parent b6281b5 commit 0c02cad

4 files changed

Lines changed: 373 additions & 9 deletions

File tree

datafusion/common/src/scalar/mod.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1459,7 +1459,10 @@ impl ScalarValue {
14591459
ScalarValue::DurationMillisecond(v) => v.is_none(),
14601460
ScalarValue::DurationMicrosecond(v) => v.is_none(),
14611461
ScalarValue::DurationNanosecond(v) => v.is_none(),
1462-
ScalarValue::Union(v, _, _) => v.is_none(),
1462+
ScalarValue::Union(v, _, _) => match v {
1463+
Some((_, s)) => s.is_null(),
1464+
None => true,
1465+
},
14631466
ScalarValue::Dictionary(_, v) => v.is_null(),
14641467
}
14651468
}
@@ -6514,4 +6517,33 @@ mod tests {
65146517
}
65156518
intervals
65166519
}
6520+
6521+
fn union_fields() -> UnionFields {
6522+
[
6523+
(0, Arc::new(Field::new("A", DataType::Int32, true))),
6524+
(1, Arc::new(Field::new("B", DataType::Float64, true))),
6525+
]
6526+
.into_iter()
6527+
.collect()
6528+
}
6529+
6530+
#[test]
6531+
fn sparse_scalar_union_is_null() {
6532+
let sparse_scalar = ScalarValue::Union(
6533+
Some((0_i8, Box::new(ScalarValue::Int32(None)))),
6534+
union_fields(),
6535+
UnionMode::Sparse,
6536+
);
6537+
assert!(sparse_scalar.is_null());
6538+
}
6539+
6540+
#[test]
6541+
fn dense_scalar_union_is_null() {
6542+
let dense_scalar = ScalarValue::Union(
6543+
Some((0_i8, Box::new(ScalarValue::Int32(None)))),
6544+
union_fields(),
6545+
UnionMode::Dense,
6546+
);
6547+
assert!(dense_scalar.is_null());
6548+
}
65176549
}

datafusion/core/tests/dataframe/mod.rs

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ use arrow::{
2929
},
3030
record_batch::RecordBatch,
3131
};
32-
use arrow_array::Float32Array;
33-
use arrow_schema::ArrowError;
32+
use arrow_array::{Array, Float32Array, Float64Array, UnionArray};
33+
use arrow_buffer::ScalarBuffer;
34+
use arrow_schema::{ArrowError, UnionFields, UnionMode};
3435
use datafusion_functions_aggregate::count::count_udaf;
3536
use object_store::local::LocalFileSystem;
3637
use std::fs;
@@ -2195,3 +2196,163 @@ async fn write_parquet_results() -> Result<()> {
21952196

21962197
Ok(())
21972198
}
2199+
2200+
fn union_fields() -> UnionFields {
2201+
[
2202+
(0, Arc::new(Field::new("A", DataType::Int32, true))),
2203+
(1, Arc::new(Field::new("B", DataType::Float64, true))),
2204+
(2, Arc::new(Field::new("C", DataType::Utf8, true))),
2205+
]
2206+
.into_iter()
2207+
.collect()
2208+
}
2209+
2210+
#[tokio::test]
2211+
async fn sparse_union_is_null() {
2212+
// union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}]
2213+
let int_array = Int32Array::from(vec![Some(1), None, None, None, None, None]);
2214+
let float_array = Float64Array::from(vec![None, None, Some(3.2), None, None, None]);
2215+
let str_array = StringArray::from(vec![None, None, None, None, Some("a"), None]);
2216+
let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::<ScalarBuffer<i8>>();
2217+
2218+
let children = vec![
2219+
Arc::new(int_array) as Arc<dyn Array>,
2220+
Arc::new(float_array),
2221+
Arc::new(str_array),
2222+
];
2223+
2224+
let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2225+
2226+
let field = Field::new(
2227+
"my_union",
2228+
DataType::Union(union_fields(), UnionMode::Sparse),
2229+
true,
2230+
);
2231+
let schema = Arc::new(Schema::new(vec![field]));
2232+
2233+
let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();
2234+
2235+
let ctx = SessionContext::new();
2236+
2237+
ctx.register_batch("union_batch", batch).unwrap();
2238+
2239+
let df = ctx.table("union_batch").await.unwrap();
2240+
2241+
// view_all
2242+
let expected = [
2243+
"+----------+",
2244+
"| my_union |",
2245+
"+----------+",
2246+
"| {A=1} |",
2247+
"| {A=} |",
2248+
"| {B=3.2} |",
2249+
"| {B=} |",
2250+
"| {C=a} |",
2251+
"| {C=} |",
2252+
"+----------+",
2253+
];
2254+
assert_batches_sorted_eq!(expected, &df.clone().collect().await.unwrap());
2255+
2256+
// filter where is null
2257+
let result_df = df.clone().filter(col("my_union").is_null()).unwrap();
2258+
let expected = [
2259+
"+----------+",
2260+
"| my_union |",
2261+
"+----------+",
2262+
"| {A=} |",
2263+
"| {B=} |",
2264+
"| {C=} |",
2265+
"+----------+",
2266+
];
2267+
assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
2268+
2269+
// filter where is not null
2270+
let result_df = df.filter(col("my_union").is_not_null()).unwrap();
2271+
let expected = [
2272+
"+----------+",
2273+
"| my_union |",
2274+
"+----------+",
2275+
"| {A=1} |",
2276+
"| {B=3.2} |",
2277+
"| {C=a} |",
2278+
"+----------+",
2279+
];
2280+
assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
2281+
}
2282+
2283+
#[tokio::test]
2284+
async fn dense_union_is_null() {
2285+
// union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}]
2286+
let int_array = Int32Array::from(vec![Some(1), None]);
2287+
let float_array = Float64Array::from(vec![Some(3.2), None]);
2288+
let str_array = StringArray::from(vec![Some("a"), None]);
2289+
let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::<ScalarBuffer<i8>>();
2290+
let offsets = [0, 1, 0, 1, 0, 1]
2291+
.into_iter()
2292+
.collect::<ScalarBuffer<i32>>();
2293+
2294+
let children = vec![
2295+
Arc::new(int_array) as Arc<dyn Array>,
2296+
Arc::new(float_array),
2297+
Arc::new(str_array),
2298+
];
2299+
2300+
let array =
2301+
UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();
2302+
2303+
let field = Field::new(
2304+
"my_union",
2305+
DataType::Union(union_fields(), UnionMode::Dense),
2306+
true,
2307+
);
2308+
let schema = Arc::new(Schema::new(vec![field]));
2309+
2310+
let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();
2311+
2312+
let ctx = SessionContext::new();
2313+
2314+
ctx.register_batch("union_batch", batch).unwrap();
2315+
2316+
let df = ctx.table("union_batch").await.unwrap();
2317+
2318+
// view_all
2319+
let expected = [
2320+
"+----------+",
2321+
"| my_union |",
2322+
"+----------+",
2323+
"| {A=1} |",
2324+
"| {A=} |",
2325+
"| {B=3.2} |",
2326+
"| {B=} |",
2327+
"| {C=a} |",
2328+
"| {C=} |",
2329+
"+----------+",
2330+
];
2331+
assert_batches_sorted_eq!(expected, &df.clone().collect().await.unwrap());
2332+
2333+
// filter where is null
2334+
let result_df = df.clone().filter(col("my_union").is_null()).unwrap();
2335+
let expected = [
2336+
"+----------+",
2337+
"| my_union |",
2338+
"+----------+",
2339+
"| {A=} |",
2340+
"| {B=} |",
2341+
"| {C=} |",
2342+
"+----------+",
2343+
];
2344+
assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
2345+
2346+
// filter where is not null
2347+
let result_df = df.filter(col("my_union").is_not_null()).unwrap();
2348+
let expected = [
2349+
"+----------+",
2350+
"| my_union |",
2351+
"+----------+",
2352+
"| {A=1} |",
2353+
"| {B=3.2} |",
2354+
"| {C=a} |",
2355+
"+----------+",
2356+
];
2357+
assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
2358+
}

datafusion/physical-expr/src/expressions/is_not_null.rs

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,11 @@ impl PhysicalExpr for IsNotNullExpr {
7373
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
7474
let arg = self.arg.evaluate(batch)?;
7575
match arg {
76-
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new(
77-
compute::is_not_null(array.as_ref())?,
78-
))),
76+
ColumnarValue::Array(array) => {
77+
let is_null = super::is_null::compute_is_null(array)?;
78+
let is_not_null = compute::not(&is_null)?;
79+
Ok(ColumnarValue::Array(Arc::new(is_not_null)))
80+
}
7981
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(
8082
ScalarValue::Boolean(Some(!scalar.is_null())),
8183
)),
@@ -120,6 +122,8 @@ mod tests {
120122
array::{BooleanArray, StringArray},
121123
datatypes::*,
122124
};
125+
use arrow_array::{Array, Float64Array, Int32Array, UnionArray};
126+
use arrow_buffer::ScalarBuffer;
123127
use datafusion_common::cast::as_boolean_array;
124128

125129
#[test]
@@ -143,4 +147,48 @@ mod tests {
143147

144148
Ok(())
145149
}
150+
151+
#[test]
152+
fn union_is_not_null_op() {
153+
// union of [{A=1}, {A=}, {B=1.1}, {B=1.2}, {B=}]
154+
let int_array = Int32Array::from(vec![Some(1), None, None, None, None]);
155+
let float_array =
156+
Float64Array::from(vec![None, None, Some(1.1), Some(1.2), None]);
157+
let type_ids = [0, 0, 1, 1, 1].into_iter().collect::<ScalarBuffer<i8>>();
158+
159+
let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
160+
161+
let union_fields: UnionFields = [
162+
(0, Arc::new(Field::new("A", DataType::Int32, true))),
163+
(1, Arc::new(Field::new("B", DataType::Float64, true))),
164+
]
165+
.into_iter()
166+
.collect();
167+
168+
let array =
169+
UnionArray::try_new(union_fields.clone(), type_ids, None, children).unwrap();
170+
171+
let field = Field::new(
172+
"my_union",
173+
DataType::Union(union_fields, UnionMode::Sparse),
174+
true,
175+
);
176+
177+
let schema = Schema::new(vec![field]);
178+
let expr = is_not_null(col("my_union", &schema).unwrap()).unwrap();
179+
let batch =
180+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap();
181+
182+
// expression: "my_union is not null"
183+
let actual = expr
184+
.evaluate(&batch)
185+
.unwrap()
186+
.into_array(batch.num_rows())
187+
.expect("Failed to convert to array");
188+
let actual = as_boolean_array(&actual).unwrap();
189+
190+
let expected = &BooleanArray::from(vec![true, false, true, true, false]);
191+
192+
assert_eq!(expected, actual);
193+
}
146194
}

0 commit comments

Comments
 (0)