Skip to content

Commit bfccb5f

Browse files
authored
filter kernel should work with UnionArray (#1412)
1 parent f0646f8 commit bfccb5f

4 files changed

Lines changed: 334 additions & 5 deletions

File tree

arrow/src/array/data.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,16 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
194194
MutableBuffer::new(capacity * mem::size_of::<u8>()),
195195
empty_buffer,
196196
],
197-
DataType::Union(_, _) => unimplemented!(),
197+
DataType::Union(_, mode) => {
198+
let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
199+
match mode {
200+
UnionMode::Sparse => [type_ids, empty_buffer],
201+
UnionMode::Dense => {
202+
let offsets = MutableBuffer::new(capacity * mem::size_of::<i32>());
203+
[type_ids, offsets]
204+
}
205+
}
206+
}
198207
}
199208
}
200209

@@ -210,7 +219,8 @@ pub(crate) fn into_buffers(
210219
DataType::Utf8
211220
| DataType::Binary
212221
| DataType::LargeUtf8
213-
| DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
222+
| DataType::LargeBinary
223+
| DataType::Union(_, _) => vec![buffer1.into(), buffer2.into()],
214224
_ => vec![buffer1.into()],
215225
}
216226
}
@@ -559,7 +569,10 @@ impl ArrayData {
559569
DataType::Map(field, _) => {
560570
vec![Self::new_empty(field.data_type())]
561571
}
562-
DataType::Union(_, _) => unimplemented!(),
572+
DataType::Union(fields, _) => fields
573+
.iter()
574+
.map(|field| Self::new_empty(field.data_type()))
575+
.collect(),
563576
DataType::Dictionary(_, data_type) => {
564577
vec![Self::new_empty(data_type)]
565578
}

arrow/src/array/transform/mod.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ mod list;
3535
mod null;
3636
mod primitive;
3737
mod structure;
38+
mod union;
3839
mod utils;
3940
mod variable_size;
4041

@@ -272,9 +273,12 @@ fn build_extend(array: &ArrayData) -> Extend {
272273
DataType::Struct(_) => structure::build_extend(array),
273274
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
274275
DataType::Float16 => primitive::build_extend::<f16>(array),
276+
DataType::Union(_, mode) => match mode {
277+
UnionMode::Sparse => union::build_extend_sparse(array),
278+
UnionMode::Dense => union::build_extend_dense(array),
279+
},
275280
/*
276281
DataType::FixedSizeList(_, _) => {}
277-
DataType::Union(_) => {}
278282
*/
279283
ty => todo!(
280284
"Take and filter operations still not supported for this datatype: `{:?}`",
@@ -326,9 +330,12 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
326330
DataType::Struct(_) => structure::extend_nulls,
327331
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
328332
DataType::Float16 => primitive::extend_nulls::<f16>,
333+
DataType::Union(_, mode) => match mode {
334+
UnionMode::Sparse => union::extend_nulls_sparse,
335+
UnionMode::Dense => union::extend_nulls_dense,
336+
},
329337
/*
330338
DataType::FixedSizeList(_, _) => {}
331-
DataType::Union(_) => {}
332339
*/
333340
ty => todo!(
334341
"Take and filter operations still not supported for this datatype: `{:?}`",
@@ -522,6 +529,15 @@ impl<'a> MutableArrayData<'a> {
522529
})
523530
.collect::<Vec<_>>(),
524531
},
532+
DataType::Union(fields, _) => (0..fields.len())
533+
.map(|i| {
534+
let child_arrays = arrays
535+
.iter()
536+
.map(|array| &array.child_data()[i])
537+
.collect::<Vec<_>>();
538+
MutableArrayData::new(child_arrays, use_nulls, array_capacity)
539+
})
540+
.collect::<Vec<_>>(),
525541
ty => {
526542
todo!("Take and filter operations still not supported for this datatype: `{:?}`", ty)
527543
}

arrow/src/array/transform/union.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::array::ArrayData;
19+
20+
use super::{Extend, _MutableArrayData};
21+
22+
pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend {
23+
let type_ids = array.buffer::<i8>(0);
24+
25+
if array.null_count() == 0 {
26+
Box::new(
27+
move |mutable: &mut _MutableArrayData,
28+
index: usize,
29+
start: usize,
30+
len: usize| {
31+
// extends type_ids
32+
mutable
33+
.buffer1
34+
.extend_from_slice(&type_ids[start..start + len]);
35+
36+
mutable
37+
.child_data
38+
.iter_mut()
39+
.for_each(|child| child.extend(index, start, start + len))
40+
},
41+
)
42+
} else {
43+
Box::new(
44+
move |mutable: &mut _MutableArrayData,
45+
index: usize,
46+
start: usize,
47+
len: usize| {
48+
// extends type_ids
49+
mutable
50+
.buffer1
51+
.extend_from_slice(&type_ids[start..start + len]);
52+
53+
(start..start + len).for_each(|i| {
54+
if array.is_valid(i) {
55+
mutable
56+
.child_data
57+
.iter_mut()
58+
.for_each(|child| child.extend(index, i, i + 1))
59+
} else {
60+
mutable
61+
.child_data
62+
.iter_mut()
63+
.for_each(|child| child.extend_nulls(1))
64+
}
65+
})
66+
},
67+
)
68+
}
69+
}
70+
71+
pub(super) fn build_extend_dense(array: &ArrayData) -> Extend {
72+
let type_ids = array.buffer::<i8>(0);
73+
let offsets = array.buffer::<i32>(1);
74+
75+
if array.null_count() == 0 {
76+
Box::new(
77+
move |mutable: &mut _MutableArrayData,
78+
index: usize,
79+
start: usize,
80+
len: usize| {
81+
// extends type_ids
82+
mutable
83+
.buffer1
84+
.extend_from_slice(&type_ids[start..start + len]);
85+
// extends offsets
86+
mutable
87+
.buffer2
88+
.extend_from_slice(&offsets[start..start + len]);
89+
90+
(start..start + len).for_each(|i| {
91+
let type_id = type_ids[i] as usize;
92+
let offset_start = offsets[start] as usize;
93+
94+
mutable.child_data[type_id].extend(
95+
index,
96+
offset_start,
97+
offset_start + 1,
98+
)
99+
})
100+
},
101+
)
102+
} else {
103+
Box::new(
104+
move |mutable: &mut _MutableArrayData,
105+
index: usize,
106+
start: usize,
107+
len: usize| {
108+
// extends type_ids
109+
mutable
110+
.buffer1
111+
.extend_from_slice(&type_ids[start..start + len]);
112+
// extends offsets
113+
mutable
114+
.buffer2
115+
.extend_from_slice(&offsets[start..start + len]);
116+
117+
(start..start + len).for_each(|i| {
118+
let type_id = type_ids[i] as usize;
119+
let offset_start = offsets[start] as usize;
120+
121+
if array.is_valid(i) {
122+
mutable.child_data[type_id].extend(
123+
index,
124+
offset_start,
125+
offset_start + 1,
126+
)
127+
} else {
128+
mutable.child_data[type_id].extend_nulls(1)
129+
}
130+
})
131+
},
132+
)
133+
}
134+
}
135+
136+
/// Appends `len` null slots to a dense union under construction.
///
/// The nulls are spread over the children as evenly as possible: every child
/// receives `len / children` nulls and the first `len % children` children
/// receive one more, so exactly `len` slots are appended. For each new slot
/// the child's index is written to `buffer1` as the type id and the position
/// of the freshly appended null inside that child is written to `buffer2`.
///
/// # Panics
///
/// Panics if `len > 0` and the union has no children, because there is no
/// child that could hold the nulls.
pub(super) fn extend_nulls_dense(mutable: &mut _MutableArrayData, len: usize) {
    let num_children = mutable.child_data.len();
    if num_children == 0 {
        // Guard against the division by zero below with a readable message.
        assert_eq!(
            len, 0,
            "cannot append nulls to a dense union that has no children"
        );
        return;
    }

    let per_child = len / num_children;
    // Slots left over after the even split; handed out one per child so the
    // total is exactly `len` instead of being silently truncated.
    let remainder = len % num_children;

    for (idx, child) in mutable.child_data.iter_mut().enumerate() {
        let n = per_child + usize::from(idx < remainder);

        // The new nulls occupy consecutive positions at the end of the child.
        let first = child.len() as i32;
        let offsets: Vec<i32> = (first..first + n as i32).collect();

        mutable
            .buffer1
            .extend_from_slice(vec![idx as i8; n].as_slice());
        mutable.buffer2.extend_from_slice(offsets.as_slice());
        child.extend_nulls(n);
    }
}
155+
156+
/// Appends `len` null slots to a sparse union under construction.
///
/// Every child is extended with `len` nulls, and `len` type ids are appended
/// to `buffer1` so the type-id buffer keeps one entry per slot. The new slots
/// are tagged with type id `0`: every child holds a null at these positions,
/// so whichever child the id selects yields a null.
///
/// NOTE(review): using `0` assumes the first child's type id is 0, matching
/// the child-index-as-type-id convention used by `extend_nulls_dense` --
/// confirm.
pub(super) fn extend_nulls_sparse(mutable: &mut _MutableArrayData, len: usize) {
    // extends type_ids; without this the buffer is shorter than the array
    mutable
        .buffer1
        .extend_from_slice(vec![0i8; len].as_slice());

    mutable
        .child_data
        .iter_mut()
        .for_each(|child| child.extend_nulls(len))
}

arrow/src/compute/kernels/filter.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,4 +1521,143 @@ mod tests {
15211521

15221522
assert_eq!(&expected, &got);
15231523
}
1524+
1525+
fn test_filter_union_array(array: UnionArray) {
1526+
let filter_array = BooleanArray::from(vec![true, false, false]);
1527+
let c = filter(&array, &filter_array).unwrap();
1528+
let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1529+
1530+
let mut builder = UnionBuilder::new_dense(1);
1531+
builder.append::<Int32Type>("A", 1).unwrap();
1532+
let expected_array = builder.build().unwrap();
1533+
1534+
compare_union_arrays(filtered, &expected_array);
1535+
1536+
let filter_array = BooleanArray::from(vec![true, false, true]);
1537+
let c = filter(&array, &filter_array).unwrap();
1538+
let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1539+
1540+
let mut builder = UnionBuilder::new_dense(2);
1541+
builder.append::<Int32Type>("A", 1).unwrap();
1542+
builder.append::<Int32Type>("A", 34).unwrap();
1543+
let expected_array = builder.build().unwrap();
1544+
1545+
compare_union_arrays(filtered, &expected_array);
1546+
1547+
let filter_array = BooleanArray::from(vec![true, true, false]);
1548+
let c = filter(&array, &filter_array).unwrap();
1549+
let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1550+
1551+
let mut builder = UnionBuilder::new_dense(2);
1552+
builder.append::<Int32Type>("A", 1).unwrap();
1553+
builder.append::<Float64Type>("B", 3.2).unwrap();
1554+
let expected_array = builder.build().unwrap();
1555+
1556+
compare_union_arrays(filtered, &expected_array);
1557+
}
1558+
1559+
/// Runs the shared filter checks against a dense union.
#[test]
fn test_filter_union_array_dense() {
    // Slots: Int32 "A" = 1, Float64 "B" = 3.2, Int32 "A" = 34.
    let mut dense = UnionBuilder::new_dense(3);
    dense.append::<Int32Type>("A", 1).unwrap();
    dense.append::<Float64Type>("B", 3.2).unwrap();
    dense.append::<Int32Type>("A", 34).unwrap();
    let input = dense.build().unwrap();

    test_filter_union_array(input);
}
1569+
1570+
/// Filters a dense union containing a null and expects the null to survive.
#[test]
fn test_filter_union_array_dense_with_nulls() {
    // Slots: Int32 "A" = 1, Float64 "B" = 3.2, null, Int32 "A" = 34.
    let mut source = UnionBuilder::new_dense(4);
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null().unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let input = source.build().unwrap();

    // Keep slot 0 (a value) and slot 2 (the null).
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let kept = filter(&input, &mask).unwrap();
    let actual = kept.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense(1);
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null().unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(actual, &expected);
}
1590+
1591+
/// Runs the shared filter checks against a sparse union.
#[test]
fn test_filter_union_array_sparse() {
    // Slots: Int32 "A" = 1, Float64 "B" = 3.2, Int32 "A" = 34.
    let mut sparse = UnionBuilder::new_sparse(3);
    sparse.append::<Int32Type>("A", 1).unwrap();
    sparse.append::<Float64Type>("B", 3.2).unwrap();
    sparse.append::<Int32Type>("A", 34).unwrap();
    let input = sparse.build().unwrap();

    test_filter_union_array(input);
}
1601+
1602+
/// Filters a sparse union containing a null and expects the null to survive.
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    // Slots: Int32 "A" = 1, Float64 "B" = 3.2, null, Int32 "A" = 34.
    let mut source = UnionBuilder::new_sparse(4);
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null().unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let input = source.build().unwrap();

    // Keep slot 0 (a value) and slot 2 (the null).
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let kept = filter(&input, &mask).unwrap();
    let actual = kept.as_any().downcast_ref::<UnionArray>().unwrap();

    // The expectation is built with the dense builder; the comparison below
    // only looks at per-slot values and nullness.
    let mut wanted = UnionBuilder::new_dense(1);
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null().unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(actual, &expected);
}
1622+
1623+
fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1624+
assert_eq!(union1.len(), union2.len());
1625+
1626+
for i in 0..union1.len() {
1627+
let type_id = union1.type_id(i);
1628+
1629+
let slot1 = union1.value(i);
1630+
let slot2 = union2.value(i);
1631+
1632+
assert_eq!(union1.is_null(i), union2.is_null(i));
1633+
1634+
if !union1.is_null(i) && !union2.is_null(i) {
1635+
match type_id {
1636+
0 => {
1637+
let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1638+
assert_eq!(slot1.len(), 1);
1639+
let value1 = slot1.value(0);
1640+
1641+
let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1642+
assert_eq!(slot2.len(), 1);
1643+
let value2 = slot2.value(0);
1644+
assert_eq!(value1, value2);
1645+
}
1646+
1 => {
1647+
let slot1 =
1648+
slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1649+
assert_eq!(slot1.len(), 1);
1650+
let value1 = slot1.value(0);
1651+
1652+
let slot2 =
1653+
slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1654+
assert_eq!(slot2.len(), 1);
1655+
let value2 = slot2.value(0);
1656+
assert_eq!(value1, value2);
1657+
}
1658+
_ => unreachable!(),
1659+
}
1660+
}
1661+
}
1662+
}
15241663
}

0 commit comments

Comments
 (0)