Skip to content

Commit 510f02f

Browse files
authored
Speed up bound checking in take (#281)
* WIP improve take performance
* WIP
* Bound checking speed
* Simplify
* fmt
* Improve formatting
1 parent aba044f commit 510f02f

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

arrow/benches/take_kernels.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use rand::Rng;
2323

2424
extern crate arrow;
2525

26-
use arrow::compute::take;
26+
use arrow::compute::{take, TakeOptions};
2727
use arrow::datatypes::*;
2828
use arrow::util::test_util::seedable_rng;
2929
use arrow::{array::*, util::bench_util::*};
@@ -46,6 +46,12 @@ fn bench_take(values: &dyn Array, indices: &UInt32Array) {
4646
criterion::black_box(take(values, &indices, None).unwrap());
4747
}
4848

49+
/// Benchmark helper: runs `take` over `values` at `indices` with
/// `TakeOptions { check_bounds: true }`, so the bounds-checking path is measured.
///
/// The `take` result is unwrapped (an `Err` would panic the benchmark) and handed
/// to `criterion::black_box` so the optimizer cannot discard the work.
fn bench_take_bounds_check(values: &dyn Array, indices: &UInt32Array) {
50+
criterion::black_box(
51+
take(values, &indices, Some(TakeOptions { check_bounds: true })).unwrap(),
52+
);
53+
}
54+
4955
fn add_benchmark(c: &mut Criterion) {
5056
let values = create_primitive_array::<Int32Type>(512, 0.0);
5157
let indices = create_random_index(512, 0.0);
@@ -56,6 +62,17 @@ fn add_benchmark(c: &mut Criterion) {
5662
b.iter(|| bench_take(&values, &indices))
5763
});
5864

65+
let values = create_primitive_array::<Int32Type>(512, 0.0);
66+
let indices = create_random_index(512, 0.0);
67+
c.bench_function("take check bounds i32 512", |b| {
68+
b.iter(|| bench_take_bounds_check(&values, &indices))
69+
});
70+
let values = create_primitive_array::<Int32Type>(1024, 0.0);
71+
let indices = create_random_index(1024, 0.0);
72+
c.bench_function("take check bounds i32 1024", |b| {
73+
b.iter(|| bench_take_bounds_check(&values, &indices))
74+
});
75+
5976
let indices = create_random_index(512, 0.5);
6077
c.bench_function("take i32 nulls 512", |b| {
6178
b.iter(|| bench_take(&values, &indices))

arrow/src/compute/kernels/take.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,30 @@ where
100100
let options = options.unwrap_or_default();
101101
if options.check_bounds {
102102
let len = values.len();
103-
for i in 0..indices.len() {
104-
if indices.is_valid(i) {
105-
let ix = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
103+
if indices.null_count() > 0 {
104+
indices.iter().flatten().try_for_each(|index| {
105+
let ix = ToPrimitive::to_usize(&index).ok_or_else(|| {
106106
ArrowError::ComputeError("Cast to usize failed".to_string())
107107
})?;
108108
if ix >= len {
109109
return Err(ArrowError::ComputeError(
110-
format!("Array index out of bounds, cannot get item at index {} from {} entries", ix, len))
111-
);
110+
format!("Array index out of bounds, cannot get item at index {} from {} entries", ix, len))
111+
);
112112
}
113-
}
113+
Ok(())
114+
})?;
115+
} else {
116+
indices.values().iter().try_for_each(|index| {
117+
let ix = ToPrimitive::to_usize(index).ok_or_else(|| {
118+
ArrowError::ComputeError("Cast to usize failed".to_string())
119+
})?;
120+
if ix >= len {
121+
return Err(ArrowError::ComputeError(
122+
format!("Array index out of bounds, cannot get item at index {} from {} entries", ix, len))
123+
);
124+
}
125+
Ok(())
126+
})?
114127
}
115128
}
116129
match values.data_type() {

0 commit comments

Comments
 (0)