@@ -29,8 +29,9 @@ use arrow::{
2929 } ,
3030 record_batch:: RecordBatch ,
3131} ;
32- use arrow_array:: Float32Array ;
33- use arrow_schema:: ArrowError ;
32+ use arrow_array:: { Array , Float32Array , Float64Array , UnionArray } ;
33+ use arrow_buffer:: ScalarBuffer ;
34+ use arrow_schema:: { ArrowError , UnionFields , UnionMode } ;
3435use datafusion_functions_aggregate:: count:: count_udaf;
3536use object_store:: local:: LocalFileSystem ;
3637use std:: fs;
@@ -2195,3 +2196,163 @@ async fn write_parquet_results() -> Result<()> {
21952196
21962197 Ok ( ( ) )
21972198}
2199+
2200+ fn union_fields ( ) -> UnionFields {
2201+ [
2202+ ( 0 , Arc :: new ( Field :: new ( "A" , DataType :: Int32 , true ) ) ) ,
2203+ ( 1 , Arc :: new ( Field :: new ( "B" , DataType :: Float64 , true ) ) ) ,
2204+ ( 2 , Arc :: new ( Field :: new ( "C" , DataType :: Utf8 , true ) ) ) ,
2205+ ]
2206+ . into_iter ( )
2207+ . collect ( )
2208+ }
2209+
2210+ #[ tokio:: test]
2211+ async fn sparse_union_is_null ( ) {
2212+ // union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}]
2213+ let int_array = Int32Array :: from ( vec ! [ Some ( 1 ) , None , None , None , None , None ] ) ;
2214+ let float_array = Float64Array :: from ( vec ! [ None , None , Some ( 3.2 ) , None , None , None ] ) ;
2215+ let str_array = StringArray :: from ( vec ! [ None , None , None , None , Some ( "a" ) , None ] ) ;
2216+ let type_ids = [ 0 , 0 , 1 , 1 , 2 , 2 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
2217+
2218+ let children = vec ! [
2219+ Arc :: new( int_array) as Arc <dyn Array >,
2220+ Arc :: new( float_array) ,
2221+ Arc :: new( str_array) ,
2222+ ] ;
2223+
2224+ let array = UnionArray :: try_new ( union_fields ( ) , type_ids, None , children) . unwrap ( ) ;
2225+
2226+ let field = Field :: new (
2227+ "my_union" ,
2228+ DataType :: Union ( union_fields ( ) , UnionMode :: Sparse ) ,
2229+ true ,
2230+ ) ;
2231+ let schema = Arc :: new ( Schema :: new ( vec ! [ field] ) ) ;
2232+
2233+ let batch = RecordBatch :: try_new ( schema, vec ! [ Arc :: new( array) ] ) . unwrap ( ) ;
2234+
2235+ let ctx = SessionContext :: new ( ) ;
2236+
2237+ ctx. register_batch ( "union_batch" , batch) . unwrap ( ) ;
2238+
2239+ let df = ctx. table ( "union_batch" ) . await . unwrap ( ) ;
2240+
2241+ // view_all
2242+ let expected = [
2243+ "+----------+" ,
2244+ "| my_union |" ,
2245+ "+----------+" ,
2246+ "| {A=1} |" ,
2247+ "| {A=} |" ,
2248+ "| {B=3.2} |" ,
2249+ "| {B=} |" ,
2250+ "| {C=a} |" ,
2251+ "| {C=} |" ,
2252+ "+----------+" ,
2253+ ] ;
2254+ assert_batches_sorted_eq ! ( expected, & df. clone( ) . collect( ) . await . unwrap( ) ) ;
2255+
2256+ // filter where is null
2257+ let result_df = df. clone ( ) . filter ( col ( "my_union" ) . is_null ( ) ) . unwrap ( ) ;
2258+ let expected = [
2259+ "+----------+" ,
2260+ "| my_union |" ,
2261+ "+----------+" ,
2262+ "| {A=} |" ,
2263+ "| {B=} |" ,
2264+ "| {C=} |" ,
2265+ "+----------+" ,
2266+ ] ;
2267+ assert_batches_sorted_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2268+
2269+ // filter where is not null
2270+ let result_df = df. filter ( col ( "my_union" ) . is_not_null ( ) ) . unwrap ( ) ;
2271+ let expected = [
2272+ "+----------+" ,
2273+ "| my_union |" ,
2274+ "+----------+" ,
2275+ "| {A=1} |" ,
2276+ "| {B=3.2} |" ,
2277+ "| {C=a} |" ,
2278+ "+----------+" ,
2279+ ] ;
2280+ assert_batches_sorted_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2281+ }
2282+
2283+ #[ tokio:: test]
2284+ async fn dense_union_is_null ( ) {
2285+ // union of [{A=1}, null, {B=3.2}, {A=34}]
2286+ let int_array = Int32Array :: from ( vec ! [ Some ( 1 ) , None ] ) ;
2287+ let float_array = Float64Array :: from ( vec ! [ Some ( 3.2 ) , None ] ) ;
2288+ let str_array = StringArray :: from ( vec ! [ Some ( "a" ) , None ] ) ;
2289+ let type_ids = [ 0 , 0 , 1 , 1 , 2 , 2 ] . into_iter ( ) . collect :: < ScalarBuffer < i8 > > ( ) ;
2290+ let offsets = [ 0 , 1 , 0 , 1 , 0 , 1 ]
2291+ . into_iter ( )
2292+ . collect :: < ScalarBuffer < i32 > > ( ) ;
2293+
2294+ let children = vec ! [
2295+ Arc :: new( int_array) as Arc <dyn Array >,
2296+ Arc :: new( float_array) ,
2297+ Arc :: new( str_array) ,
2298+ ] ;
2299+
2300+ let array =
2301+ UnionArray :: try_new ( union_fields ( ) , type_ids, Some ( offsets) , children) . unwrap ( ) ;
2302+
2303+ let field = Field :: new (
2304+ "my_union" ,
2305+ DataType :: Union ( union_fields ( ) , UnionMode :: Dense ) ,
2306+ true ,
2307+ ) ;
2308+ let schema = Arc :: new ( Schema :: new ( vec ! [ field] ) ) ;
2309+
2310+ let batch = RecordBatch :: try_new ( schema, vec ! [ Arc :: new( array) ] ) . unwrap ( ) ;
2311+
2312+ let ctx = SessionContext :: new ( ) ;
2313+
2314+ ctx. register_batch ( "union_batch" , batch) . unwrap ( ) ;
2315+
2316+ let df = ctx. table ( "union_batch" ) . await . unwrap ( ) ;
2317+
2318+ // view_all
2319+ let expected = [
2320+ "+----------+" ,
2321+ "| my_union |" ,
2322+ "+----------+" ,
2323+ "| {A=1} |" ,
2324+ "| {A=} |" ,
2325+ "| {B=3.2} |" ,
2326+ "| {B=} |" ,
2327+ "| {C=a} |" ,
2328+ "| {C=} |" ,
2329+ "+----------+" ,
2330+ ] ;
2331+ assert_batches_sorted_eq ! ( expected, & df. clone( ) . collect( ) . await . unwrap( ) ) ;
2332+
2333+ // filter where is null
2334+ let result_df = df. clone ( ) . filter ( col ( "my_union" ) . is_null ( ) ) . unwrap ( ) ;
2335+ let expected = [
2336+ "+----------+" ,
2337+ "| my_union |" ,
2338+ "+----------+" ,
2339+ "| {A=} |" ,
2340+ "| {B=} |" ,
2341+ "| {C=} |" ,
2342+ "+----------+" ,
2343+ ] ;
2344+ assert_batches_sorted_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2345+
2346+ // filter where is not null
2347+ let result_df = df. filter ( col ( "my_union" ) . is_not_null ( ) ) . unwrap ( ) ;
2348+ let expected = [
2349+ "+----------+" ,
2350+ "| my_union |" ,
2351+ "+----------+" ,
2352+ "| {A=1} |" ,
2353+ "| {B=3.2} |" ,
2354+ "| {C=a} |" ,
2355+ "+----------+" ,
2356+ ] ;
2357+ assert_batches_sorted_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2358+ }
0 commit comments