Skip to content

Commit e84fe20

Browse files
authored
Speed up filter_record_batch with one array (#637)
* Speed up filter_record_batch with one array
* Don't into()
1 parent a08b939 commit e84fe20

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

arrow/benches/filter_kernels.rs

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,15 @@
1616
// under the License.
1717
extern crate arrow;
1818

19-
use arrow::compute::Filter;
19+
use std::sync::Arc;
20+
21+
use arrow::compute::{filter_record_batch, Filter};
22+
use arrow::record_batch::RecordBatch;
2023
use arrow::util::bench_util::*;
2124

2225
use arrow::array::*;
2326
use arrow::compute::{build_filter, filter};
24-
use arrow::datatypes::{Float32Type, UInt8Type};
27+
use arrow::datatypes::{Field, Float32Type, Schema, UInt8Type};
2528

2629
use criterion::{criterion_group, criterion_main, Criterion};
2730

@@ -100,6 +103,18 @@ fn add_benchmark(c: &mut Criterion) {
100103
c.bench_function("filter context string low selectivity", |b| {
101104
b.iter(|| bench_built_filter(&sparse_filter, &data_array))
102105
});
106+
107+
let data_array = create_primitive_array::<Float32Type>(size, 0.0);
108+
109+
let field = Field::new("c1", data_array.data_type().clone(), true);
110+
let schema = Schema::new(vec![field]);
111+
112+
let batch =
113+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data_array)]).unwrap();
114+
115+
c.bench_function("filter single record batch", |b| {
116+
b.iter(|| filter_record_batch(&batch, &filter_array))
117+
});
103118
}
104119

105120
criterion_group!(benches, add_benchmark);

arrow/src/compute/kernels/filter.rs

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -288,12 +288,21 @@ pub fn filter_record_batch(
288288
return filter_record_batch(record_batch, &predicate);
289289
}
290290

291-
let filter = build_filter(predicate)?;
292-
let filtered_arrays = record_batch
293-
.columns()
294-
.iter()
295-
.map(|a| make_array(filter(a.data())))
296-
.collect();
291+
let num_colums = record_batch.columns().len();
292+
293+
let filtered_arrays = match num_colums {
294+
1 => {
295+
vec![filter(record_batch.columns()[0].as_ref(), predicate)?]
296+
}
297+
_ => {
298+
let filter = build_filter(predicate)?;
299+
record_batch
300+
.columns()
301+
.iter()
302+
.map(|a| make_array(filter(a.data())))
303+
.collect()
304+
}
305+
};
297306
RecordBatch::try_new(record_batch.schema(), filtered_arrays)
298307
}
299308

0 commit comments

Comments
 (0)