|
1 | 1 | use arrow::array::{BinaryArray, BinaryViewArray, PrimitiveArray}; |
2 | 2 | use polars_core::downcast_as_macro_arg_physical; |
3 | 3 | use polars_core::prelude::*; |
| 4 | +use polars_core::utils::{Container, first_non_null, last_non_null}; |
4 | 5 | use polars_utils::total_ord::TotalEq; |
5 | 6 | use row_encode::encode_rows_unordered; |
6 | 7 |
|
| 8 | +pub trait IndexOf { |
| 9 | + /// Find the index of a given value in the Series. |
| 10 | + fn index_of(&self, needle: Scalar) -> PolarsResult<Option<usize>>; |
| 11 | + /// Find the index of the first non-null value in the Series. |
| 12 | + fn index_of_first_not_null(&self) -> Option<usize>; |
| 13 | + /// Find the index of the last non-null value in the Series. |
| 14 | + fn index_of_last_not_null(&self) -> Option<usize>; |
| 15 | +} |
| 16 | + |
7 | 17 | /// Find the index of the value, or ``None`` if it can't be found. |
8 | 18 | fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray<DT>, value: AR::ValueT<'a>) -> Option<usize> |
9 | 19 | where |
@@ -117,103 +127,150 @@ macro_rules! try_index_of_numeric_ca { |
117 | 127 | }}; |
118 | 128 | } |
119 | 129 |
|
120 | | -/// Find the index of a given value (the first and only entry in `value_series`) |
121 | | -/// within the series. |
122 | | -pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult<Option<usize>> { |
123 | | - polars_ensure!( |
124 | | - series.dtype() == needle.dtype(), |
125 | | - InvalidOperation: "Cannot perform index_of with mismatching datatypes: {:?} and {:?}", |
126 | | - series.dtype(), |
127 | | - needle.dtype(), |
128 | | - ); |
129 | | - |
130 | | - if series.is_empty() { |
131 | | - return Ok(None); |
132 | | - } |
| 130 | +impl IndexOf for Series { |
| 131 | + /// Find the index of a given value within the Series; if not found, returns None. |
| 132 | + fn index_of(&self, needle: Scalar) -> PolarsResult<Option<usize>> { |
| 133 | + let dtype = self.dtype(); |
| 134 | + polars_ensure!( |
| 135 | + dtype == needle.dtype(), |
| 136 | + InvalidOperation: "cannot call `index_of` with mismatched datatypes: {:?} and {:?}", |
| 137 | + self.dtype(), |
| 138 | + needle.dtype(), |
| 139 | + ); |
133 | 140 |
|
134 | | - // Series is not null, and the value is null: |
135 | | - if needle.is_null() { |
136 | | - let null_count = series.null_count(); |
137 | | - if null_count == 0 { |
| 141 | + if self.is_empty() { |
138 | 142 | return Ok(None); |
139 | | - } else if null_count == series.len() { |
140 | | - return Ok(Some(0)); |
141 | 143 | } |
142 | 144 |
|
143 | | - let mut index = 0; |
144 | | - for chunk in series.chunks() { |
145 | | - let length = chunk.len(); |
146 | | - if let Some(bitmap) = chunk.validity() { |
147 | | - let leading_ones = bitmap.leading_ones(); |
148 | | - if leading_ones < length { |
149 | | - return Ok(Some(index + leading_ones)); |
150 | | - } |
| 145 | + // series is null |
| 146 | + if self.dtype().is_null() { |
| 147 | + if needle.is_null() { |
| 148 | + return Ok((!self.is_empty()).then_some(0)); |
151 | 149 | } else { |
152 | | - index += length; |
| 150 | + return Ok(None); |
153 | 151 | } |
154 | 152 | } |
155 | | - return Ok(None); |
| 153 | + |
| 154 | + // value being searched for is null |
| 155 | + if needle.is_null() { |
| 156 | + let null_count = self.null_count(); |
| 157 | + if null_count == 0 { |
| 158 | + return Ok(None); |
| 159 | + } else if null_count == self.len() { |
| 160 | + return Ok(Some(0)); |
| 161 | + } |
| 162 | + let mut index = 0; |
| 163 | + for chunk in self.chunks() { |
| 164 | + let length = chunk.len(); |
| 165 | + if let Some(bitmap) = chunk.validity() { |
| 166 | + let leading_ones = bitmap.leading_ones(); |
| 167 | + if leading_ones < length { |
| 168 | + return Ok(Some(index + leading_ones)); |
| 169 | + } |
| 170 | + } else { |
| 171 | + index += length; |
| 172 | + } |
| 173 | + } |
| 174 | + return Ok(None); |
| 175 | + } |
| 176 | + |
| 177 | + use DataType as DT; |
| 178 | + match self.dtype().to_physical() { |
| 179 | + DT::Null => unreachable!("handled above"), |
| 180 | + DT::Boolean => Ok(index_of_bool( |
| 181 | + self.bool()?, |
| 182 | + needle.value().extract_bool().unwrap(), |
| 183 | + )), |
| 184 | + dt if dt.is_primitive_numeric() => { |
| 185 | + let series = self.to_physical_repr(); |
| 186 | + Ok(downcast_as_macro_arg_physical!( |
| 187 | + series, |
| 188 | + try_index_of_numeric_ca, |
| 189 | + needle |
| 190 | + )) |
| 191 | + }, |
| 192 | + DT::String => Ok(index_of_value::<_, BinaryViewArray>( |
| 193 | + &self.str()?.as_binary(), |
| 194 | + needle.value().extract_str().unwrap().as_bytes(), |
| 195 | + )), |
| 196 | + DT::Binary => Ok(index_of_value::<_, BinaryViewArray>( |
| 197 | + self.binary()?, |
| 198 | + needle.value().extract_bytes().unwrap(), |
| 199 | + )), |
| 200 | + DT::BinaryOffset => Ok(index_of_value::<_, BinaryArray<i64>>( |
| 201 | + self.binary_offset()?, |
| 202 | + needle.value().extract_bytes().unwrap(), |
| 203 | + )), |
| 204 | + DT::Array(_, _) | DT::List(_) | DT::Struct(_) => { |
| 205 | + // For non-numeric dtypes, we convert to row-encoding, which essentially has |
| 206 | + // us searching the physical representation of the data as a series of |
| 207 | + // bytes. |
| 208 | + let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1); |
| 209 | + let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?; |
| 210 | + let value = value_as_row_encoded_ca |
| 211 | + .first() |
| 212 | + .expect("shouldn't have null values in a row-encoded result"); |
| 213 | + |
| 214 | + let ca = encode_rows_unordered(&[self.clone().into_column()])?; |
| 215 | + Ok(index_of_value::<_, BinaryArray<i64>>(&ca, value)) |
| 216 | + }, |
| 217 | + |
| 218 | + DT::UInt8 |
| 219 | + | DT::UInt16 |
| 220 | + | DT::UInt32 |
| 221 | + | DT::UInt64 |
| 222 | + | DT::Int8 |
| 223 | + | DT::Int16 |
| 224 | + | DT::Int32 |
| 225 | + | DT::Int64 |
| 226 | + | DT::Int128 |
| 227 | + | DT::Float32 |
| 228 | + | DT::Float64 => unreachable!("primitive numeric"), |
| 229 | + |
| 230 | + // to_physical |
| 231 | + #[cfg(feature = "dtype-decimal")] |
| 232 | + DT::Decimal(..) => unreachable!(), |
| 233 | + #[cfg(feature = "dtype-categorical")] |
| 234 | + DT::Categorical(..) | DT::Enum(..) => unreachable!(), |
| 235 | + DT::Date | DT::Datetime(..) | DT::Duration(..) | DT::Time => unreachable!(), |
| 236 | + |
| 237 | + DT::Object(_) | DT::Unknown(_) => polars_bail!(op = "index_of", self.dtype()), |
| 238 | + } |
| 239 | + } |
| 240 | + |
| 241 | + /// Find the index of the *first* non-null value in the |
| 242 | + /// Series; if no such value is found, returns None. |
| 243 | + fn index_of_first_not_null(&self) -> Option<usize> { |
| 244 | + // early-exit if empty, all values are null, or no values are null |
| 245 | + let n_values = self.len(); |
| 246 | + if n_values == 0 { |
| 247 | + return None; |
| 248 | + } |
| 249 | + let null_count = self.null_count(); |
| 250 | + if null_count == 0 { |
| 251 | + return Some(0); |
| 252 | + } else if null_count == n_values { |
| 253 | + return None; |
| 254 | + } |
| 255 | + // otherwise examine chunk validity bitmaps |
| 256 | + first_non_null(self.chunks().iter().map(|arr| arr.validity())) |
156 | 257 | } |
157 | 258 |
|
158 | | - use DataType as DT; |
159 | | - match series.dtype().to_physical() { |
160 | | - DT::Null => unreachable!("handled above"), |
161 | | - DT::Boolean => Ok(index_of_bool( |
162 | | - series.bool()?, |
163 | | - needle.value().extract_bool().unwrap(), |
164 | | - )), |
165 | | - dt if dt.is_primitive_numeric() => { |
166 | | - let series = series.to_physical_repr(); |
167 | | - Ok(downcast_as_macro_arg_physical!( |
168 | | - series, |
169 | | - try_index_of_numeric_ca, |
170 | | - needle |
171 | | - )) |
172 | | - }, |
173 | | - DT::String => Ok(index_of_value::<_, BinaryViewArray>( |
174 | | - &series.str()?.as_binary(), |
175 | | - needle.value().extract_str().unwrap().as_bytes(), |
176 | | - )), |
177 | | - DT::Binary => Ok(index_of_value::<_, BinaryViewArray>( |
178 | | - series.binary()?, |
179 | | - needle.value().extract_bytes().unwrap(), |
180 | | - )), |
181 | | - DT::BinaryOffset => Ok(index_of_value::<_, BinaryArray<i64>>( |
182 | | - series.binary_offset()?, |
183 | | - needle.value().extract_bytes().unwrap(), |
184 | | - )), |
185 | | - DT::Array(_, _) | DT::List(_) | DT::Struct(_) => { |
186 | | - // For non-numeric dtypes, we convert to row-encoding, which essentially has |
187 | | - // us searching the physical representation of the data as a series of |
188 | | - // bytes. |
189 | | - let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1); |
190 | | - let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?; |
191 | | - let value = value_as_row_encoded_ca |
192 | | - .first() |
193 | | - .expect("Shouldn't have nulls in a row-encoded result"); |
194 | | - let ca = encode_rows_unordered(&[series.clone().into_column()])?; |
195 | | - Ok(index_of_value::<_, BinaryArray<i64>>(&ca, value)) |
196 | | - }, |
197 | | - |
198 | | - DT::UInt8 |
199 | | - | DT::UInt16 |
200 | | - | DT::UInt32 |
201 | | - | DT::UInt64 |
202 | | - | DT::Int8 |
203 | | - | DT::Int16 |
204 | | - | DT::Int32 |
205 | | - | DT::Int64 |
206 | | - | DT::Int128 |
207 | | - | DT::Float32 |
208 | | - | DT::Float64 => unreachable!("primitive numeric"), |
209 | | - |
210 | | - // to_physical |
211 | | - #[cfg(feature = "dtype-decimal")] |
212 | | - DT::Decimal(..) => unreachable!(), |
213 | | - #[cfg(feature = "dtype-categorical")] |
214 | | - DT::Categorical(..) | DT::Enum(..) => unreachable!(), |
215 | | - DT::Date | DT::Datetime(..) | DT::Duration(..) | DT::Time => unreachable!(), |
216 | | - |
217 | | - DT::Object(_) | DT::Unknown(_) => polars_bail!(op = "index_of", series.dtype()), |
| 259 | + /// Find the index of the *last* non-null value in the |
| 260 | + /// Series; if no such value is found, returns None. |
| 261 | + fn index_of_last_not_null(&self) -> Option<usize> { |
| 262 | + // early-exit if empty, all values are null, or no values are null |
| 263 | + let n_values = self.len(); |
| 264 | + if n_values == 0 { |
| 265 | + return None; |
| 266 | + } |
| 267 | + let null_count = self.null_count(); |
| 268 | + if null_count == 0 { |
| 269 | + return Some(n_values - 1); |
| 270 | + } else if null_count == n_values { |
| 271 | + return None; |
| 272 | + } |
| 273 | + // otherwise examine chunk validity bitmaps |
| 274 | + last_non_null(self.chunks().iter().map(|arr| arr.validity()), n_values) |
218 | 275 | } |
219 | 276 | } |
0 commit comments