Skip to content

Commit 8b54ff0

Browse files
feat: add index_of_first_not_null and index_of_last_not_null Expr and Series methods
1 parent 5993d95 commit 8b54ff0

17 files changed

Lines changed: 557 additions & 117 deletions

File tree

crates/polars-ops/src/series/ops/index_of.rs

Lines changed: 146 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
use arrow::array::{BinaryArray, BinaryViewArray, PrimitiveArray};
22
use polars_core::downcast_as_macro_arg_physical;
33
use polars_core::prelude::*;
4+
use polars_core::utils::{Container, first_non_null, last_non_null};
45
use polars_utils::total_ord::TotalEq;
56
use row_encode::encode_rows_unordered;
67

8+
pub trait IndexOf {
9+
/// Find the index of a given value in the Series.
10+
fn index_of(&self, needle: Scalar) -> PolarsResult<Option<usize>>;
11+
/// Find the index of the first non-null value in the Series.
12+
fn index_of_first_not_null(&self) -> Option<usize>;
13+
/// Find the index of the last non-null value in the Series.
14+
fn index_of_last_not_null(&self) -> Option<usize>;
15+
}
16+
717
/// Find the index of the value, or ``None`` if it can't be found.
818
fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray<DT>, value: AR::ValueT<'a>) -> Option<usize>
919
where
@@ -117,103 +127,150 @@ macro_rules! try_index_of_numeric_ca {
117127
}};
118128
}
119129

120-
/// Find the index of a given value (the first and only entry in `value_series`)
121-
/// within the series.
122-
pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult<Option<usize>> {
123-
polars_ensure!(
124-
series.dtype() == needle.dtype(),
125-
InvalidOperation: "Cannot perform index_of with mismatching datatypes: {:?} and {:?}",
126-
series.dtype(),
127-
needle.dtype(),
128-
);
129-
130-
if series.is_empty() {
131-
return Ok(None);
132-
}
130+
impl IndexOf for Series {
131+
/// Find the index of a given value within the Series; if not found, returns None.
132+
fn index_of(&self, needle: Scalar) -> PolarsResult<Option<usize>> {
133+
let dtype = self.dtype();
134+
polars_ensure!(
135+
dtype == needle.dtype(),
136+
InvalidOperation: "cannot call `index_of` with mismatched datatypes: {:?} and {:?}",
137+
self.dtype(),
138+
needle.dtype(),
139+
);
133140

134-
// Series is not null, and the value is null:
135-
if needle.is_null() {
136-
let null_count = series.null_count();
137-
if null_count == 0 {
141+
if self.is_empty() {
138142
return Ok(None);
139-
} else if null_count == series.len() {
140-
return Ok(Some(0));
141143
}
142144

143-
let mut index = 0;
144-
for chunk in series.chunks() {
145-
let length = chunk.len();
146-
if let Some(bitmap) = chunk.validity() {
147-
let leading_ones = bitmap.leading_ones();
148-
if leading_ones < length {
149-
return Ok(Some(index + leading_ones));
150-
}
145+
// series is null
146+
if self.dtype().is_null() {
147+
if needle.is_null() {
148+
return Ok((!self.is_empty()).then_some(0));
151149
} else {
152-
index += length;
150+
return Ok(None);
153151
}
154152
}
155-
return Ok(None);
153+
154+
// value being searched for is null
155+
if needle.is_null() {
156+
let null_count = self.null_count();
157+
if null_count == 0 {
158+
return Ok(None);
159+
} else if null_count == self.len() {
160+
return Ok(Some(0));
161+
}
162+
let mut index = 0;
163+
for chunk in self.chunks() {
164+
let length = chunk.len();
165+
if let Some(bitmap) = chunk.validity() {
166+
let leading_ones = bitmap.leading_ones();
167+
if leading_ones < length {
168+
return Ok(Some(index + leading_ones));
169+
}
170+
} else {
171+
index += length;
172+
}
173+
}
174+
return Ok(None);
175+
}
176+
177+
use DataType as DT;
178+
match self.dtype().to_physical() {
179+
DT::Null => unreachable!("handled above"),
180+
DT::Boolean => Ok(index_of_bool(
181+
self.bool()?,
182+
needle.value().extract_bool().unwrap(),
183+
)),
184+
dt if dt.is_primitive_numeric() => {
185+
let series = self.to_physical_repr();
186+
Ok(downcast_as_macro_arg_physical!(
187+
series,
188+
try_index_of_numeric_ca,
189+
needle
190+
))
191+
},
192+
DT::String => Ok(index_of_value::<_, BinaryViewArray>(
193+
&self.str()?.as_binary(),
194+
needle.value().extract_str().unwrap().as_bytes(),
195+
)),
196+
DT::Binary => Ok(index_of_value::<_, BinaryViewArray>(
197+
self.binary()?,
198+
needle.value().extract_bytes().unwrap(),
199+
)),
200+
DT::BinaryOffset => Ok(index_of_value::<_, BinaryArray<i64>>(
201+
self.binary_offset()?,
202+
needle.value().extract_bytes().unwrap(),
203+
)),
204+
DT::Array(_, _) | DT::List(_) | DT::Struct(_) => {
205+
// For non-numeric dtypes, we convert to row-encoding, which essentially has
206+
// us searching the physical representation of the data as a series of
207+
// bytes.
208+
let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1);
209+
let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?;
210+
let value = value_as_row_encoded_ca
211+
.first()
212+
.expect("shouldn't have null values in a row-encoded result");
213+
214+
let ca = encode_rows_unordered(&[self.clone().into_column()])?;
215+
Ok(index_of_value::<_, BinaryArray<i64>>(&ca, value))
216+
},
217+
218+
DT::UInt8
219+
| DT::UInt16
220+
| DT::UInt32
221+
| DT::UInt64
222+
| DT::Int8
223+
| DT::Int16
224+
| DT::Int32
225+
| DT::Int64
226+
| DT::Int128
227+
| DT::Float32
228+
| DT::Float64 => unreachable!("primitive numeric"),
229+
230+
// to_physical
231+
#[cfg(feature = "dtype-decimal")]
232+
DT::Decimal(..) => unreachable!(),
233+
#[cfg(feature = "dtype-categorical")]
234+
DT::Categorical(..) | DT::Enum(..) => unreachable!(),
235+
DT::Date | DT::Datetime(..) | DT::Duration(..) | DT::Time => unreachable!(),
236+
237+
DT::Object(_) | DT::Unknown(_) => polars_bail!(op = "index_of", self.dtype()),
238+
}
239+
}
240+
241+
/// Find the index of the *first* non-null value in the
242+
/// Series; if no such value is found, returns None.
243+
fn index_of_first_not_null(&self) -> Option<usize> {
244+
// early-exit if empty, all values are null, or no values are null
245+
let n_values = self.len();
246+
if n_values == 0 {
247+
return None;
248+
}
249+
let null_count = self.null_count();
250+
if null_count == 0 {
251+
return Some(0);
252+
} else if null_count == n_values {
253+
return None;
254+
}
255+
// otherwise examine chunk validity bitmaps
256+
first_non_null(self.chunks().iter().map(|arr| arr.validity()))
156257
}
157258

158-
use DataType as DT;
159-
match series.dtype().to_physical() {
160-
DT::Null => unreachable!("handled above"),
161-
DT::Boolean => Ok(index_of_bool(
162-
series.bool()?,
163-
needle.value().extract_bool().unwrap(),
164-
)),
165-
dt if dt.is_primitive_numeric() => {
166-
let series = series.to_physical_repr();
167-
Ok(downcast_as_macro_arg_physical!(
168-
series,
169-
try_index_of_numeric_ca,
170-
needle
171-
))
172-
},
173-
DT::String => Ok(index_of_value::<_, BinaryViewArray>(
174-
&series.str()?.as_binary(),
175-
needle.value().extract_str().unwrap().as_bytes(),
176-
)),
177-
DT::Binary => Ok(index_of_value::<_, BinaryViewArray>(
178-
series.binary()?,
179-
needle.value().extract_bytes().unwrap(),
180-
)),
181-
DT::BinaryOffset => Ok(index_of_value::<_, BinaryArray<i64>>(
182-
series.binary_offset()?,
183-
needle.value().extract_bytes().unwrap(),
184-
)),
185-
DT::Array(_, _) | DT::List(_) | DT::Struct(_) => {
186-
// For non-numeric dtypes, we convert to row-encoding, which essentially has
187-
// us searching the physical representation of the data as a series of
188-
// bytes.
189-
let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1);
190-
let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?;
191-
let value = value_as_row_encoded_ca
192-
.first()
193-
.expect("Shouldn't have nulls in a row-encoded result");
194-
let ca = encode_rows_unordered(&[series.clone().into_column()])?;
195-
Ok(index_of_value::<_, BinaryArray<i64>>(&ca, value))
196-
},
197-
198-
DT::UInt8
199-
| DT::UInt16
200-
| DT::UInt32
201-
| DT::UInt64
202-
| DT::Int8
203-
| DT::Int16
204-
| DT::Int32
205-
| DT::Int64
206-
| DT::Int128
207-
| DT::Float32
208-
| DT::Float64 => unreachable!("primitive numeric"),
209-
210-
// to_physical
211-
#[cfg(feature = "dtype-decimal")]
212-
DT::Decimal(..) => unreachable!(),
213-
#[cfg(feature = "dtype-categorical")]
214-
DT::Categorical(..) | DT::Enum(..) => unreachable!(),
215-
DT::Date | DT::Datetime(..) | DT::Duration(..) | DT::Time => unreachable!(),
216-
217-
DT::Object(_) | DT::Unknown(_) => polars_bail!(op = "index_of", series.dtype()),
259+
/// Find the index of the *last* non-null value in the
260+
/// Series; if no such value is found, returns None.
261+
fn index_of_last_not_null(&self) -> Option<usize> {
262+
// early-exit if empty, all values are null, or no values are null
263+
let n_values = self.len();
264+
if n_values == 0 {
265+
return None;
266+
}
267+
let null_count = self.null_count();
268+
if null_count == 0 {
269+
return Some(n_values - 1);
270+
} else if null_count == n_values {
271+
return None;
272+
}
273+
// otherwise examine chunk validity bitmaps
274+
last_non_null(self.chunks().iter().map(|arr| arr.validity()), n_values)
218275
}
219276
}

crates/polars-ops/src/series/ops/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ pub use fused::*;
8888
pub use horizontal::*;
8989
pub use index::*;
9090
#[cfg(feature = "index_of")]
91-
pub use index_of::*;
91+
pub use index_of::IndexOf;
9292
pub use int_range::*;
9393
#[cfg(feature = "interpolate")]
9494
pub use interpolation::interpolate::*;

crates/polars-plan/src/dsl/function_expr/index_of.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use polars_ops::series::index_of as index_of_op;
1+
use polars_ops::series::IndexOf;
22

33
use super::*;
44

@@ -49,13 +49,38 @@ pub(super) fn index_of(s: &mut [Column]) -> PolarsResult<Column> {
4949
}
5050
})
5151
},
52-
_ => index_of_op(series, needle)?,
52+
_ => <Series as IndexOf>::index_of(series, needle)?,
5353
};
5454

5555
let av = match result {
56+
Some(idx) => AnyValue::from(idx as IdxSize),
5657
None => AnyValue::Null,
58+
};
59+
let scalar = Scalar::new(IDX_DTYPE, av);
60+
Ok(Column::new_scalar(series.name().clone(), scalar, 1))
61+
}
62+
63+
fn index_to_column<F>(s: &mut [Column], func: F) -> PolarsResult<Column>
64+
where
65+
F: FnOnce(&Series) -> Option<usize>,
66+
{
67+
let series = if let Column::Scalar(ref sc) = s[0] {
68+
&sc.as_single_value_series()
69+
} else {
70+
s[0].as_materialized_series()
71+
};
72+
let av = match func(series) {
5773
Some(idx) => AnyValue::from(idx as IdxSize),
74+
None => AnyValue::Null,
5875
};
5976
let scalar = Scalar::new(IDX_DTYPE, av);
6077
Ok(Column::new_scalar(series.name().clone(), scalar, 1))
6178
}
79+
80+
pub(super) fn index_of_first_not_null(s: &mut [Column]) -> PolarsResult<Column> {
81+
index_to_column(s, <Series as IndexOf>::index_of_first_not_null)
82+
}
83+
84+
pub(super) fn index_of_last_not_null(s: &mut [Column]) -> PolarsResult<Column> {
85+
index_to_column(s, <Series as IndexOf>::index_of_last_not_null)
86+
}

crates/polars-plan/src/dsl/function_expr/mod.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ pub enum FunctionExpr {
159159
ArgWhere,
160160
#[cfg(feature = "index_of")]
161161
IndexOf,
162+
#[cfg(feature = "index_of")]
163+
IndexOfFirstNotNull,
164+
#[cfg(feature = "index_of")]
165+
IndexOfLastNotNull,
162166
#[cfg(feature = "search_sorted")]
163167
SearchSorted {
164168
side: SearchSortedSide,
@@ -401,6 +405,10 @@ impl Hash for FunctionExpr {
401405
Pow(f) => f.hash(state),
402406
#[cfg(feature = "index_of")]
403407
IndexOf => {},
408+
#[cfg(feature = "index_of")]
409+
IndexOfFirstNotNull => {},
410+
#[cfg(feature = "index_of")]
411+
IndexOfLastNotNull => {},
404412
#[cfg(feature = "search_sorted")]
405413
SearchSorted { side, descending } => {
406414
side.hash(state);
@@ -654,6 +662,10 @@ impl Display for FunctionExpr {
654662
ArgWhere => "arg_where",
655663
#[cfg(feature = "index_of")]
656664
IndexOf => "index_of",
665+
#[cfg(feature = "index_of")]
666+
IndexOfFirstNotNull => "index_of_first_not_null",
667+
#[cfg(feature = "index_of")]
668+
IndexOfLastNotNull => "index_of_last_not_null",
657669
#[cfg(feature = "search_sorted")]
658670
SearchSorted { .. } => "search_sorted",
659671
#[cfg(feature = "range")]
@@ -945,6 +957,14 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn ColumnsUdf>> {
945957
IndexOf => {
946958
map_as_slice!(index_of::index_of)
947959
},
960+
#[cfg(feature = "index_of")]
961+
IndexOfFirstNotNull => {
962+
map_as_slice!(index_of::index_of_first_not_null)
963+
},
964+
#[cfg(feature = "index_of")]
965+
IndexOfLastNotNull => {
966+
map_as_slice!(index_of::index_of_last_not_null)
967+
},
948968
#[cfg(feature = "search_sorted")]
949969
SearchSorted { side, descending } => {
950970
map_as_slice!(search_sorted::search_sorted_impl, side, descending)
@@ -1264,6 +1284,10 @@ impl FunctionExpr {
12641284
F::IndexOf => {
12651285
FunctionOptions::aggregation().with_casting_rules(CastingRules::FirstArgLossless)
12661286
},
1287+
#[cfg(feature = "index_of")]
1288+
F::IndexOfFirstNotNull => FunctionOptions::aggregation(),
1289+
#[cfg(feature = "index_of")]
1290+
F::IndexOfLastNotNull => FunctionOptions::aggregation(),
12671291
#[cfg(feature = "search_sorted")]
12681292
F::SearchSorted { .. } => FunctionOptions::groupwise().with_supertyping(
12691293
(SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(),

crates/polars-plan/src/dsl/function_expr/schema.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ impl FunctionExpr {
4848
ArgWhere => mapper.with_dtype(IDX_DTYPE),
4949
#[cfg(feature = "index_of")]
5050
IndexOf => mapper.with_dtype(IDX_DTYPE),
51+
#[cfg(feature = "index_of")]
52+
IndexOfFirstNotNull => mapper.with_dtype(IDX_DTYPE),
53+
#[cfg(feature = "index_of")]
54+
IndexOfLastNotNull => mapper.with_dtype(IDX_DTYPE),
5155
#[cfg(feature = "search_sorted")]
5256
SearchSorted { .. } => mapper.with_dtype(IDX_DTYPE),
5357
#[cfg(feature = "range")]

0 commit comments

Comments
 (0)