Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions crates/polars-expr/src/expressions/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ impl EvalExpr {
// Batch when the total number of inner elements exceeds IdxSize::MAX to avoid
// truncating offset/length casts below. Each batch covers a contiguous row-range
// whose accumulated inner element count stays within IdxSize::MAX.
if flattened_len > IdxSize::MAX as usize {
const LIMIT: usize = (IdxSize::MAX - 1) as usize;
if flattened_len > LIMIT {
let offsets = ca.offsets()?;
// offsets_slice[i] / offsets_slice[i+1] are the start/end of row i.
let offsets_slice = offsets.as_slice();
Expand All @@ -100,7 +101,7 @@ impl EvalExpr {
if batch_row_start >= ca.len() {
break;
}
let threshold = batch_inner_start + IdxSize::MAX as i64;
let threshold = batch_inner_start + LIMIT as i64;
// Binary search for the first row whose end offset exceeds the threshold.
// offsets_slice[batch_row_start+1..] holds end offsets for rows
// batch_row_start, batch_row_start+1, …; partition_point returns how many fit.
Expand Down Expand Up @@ -250,20 +251,59 @@ impl EvalExpr {
as_list: bool,
is_agg: bool,
) -> PolarsResult<Column> {
let df = DataFrame::empty_with_height(ca.len());
let ca = ca
.trim_lists_to_normalized_offsets()
.map_or(Cow::Borrowed(ca), Cow::Owned);

// Fast path: Empty or only nulls.
if ca.null_count() == ca.len() {
let name = self.output_field.name.clone();
return Ok(Column::full_null(name, ca.len(), self.output_field.dtype()));
}

let df = DataFrame::empty_with_height(ca.len());
let ca = ca
.trim_lists_to_normalized_offsets()
.map_or(Cow::Borrowed(ca), Cow::Owned);

// SAFETY:
// We may temporarily create lengths that exceed IDXSIZE
// If that happens we slice and process in batches.
unsafe { _set_check_length(false) };
let flattened = ca.get_inner().into_column();
unsafe { _set_check_length(true) };
let flattened_len = flattened.len();
let validity = ca.rechunk_validity();
let width = ca.width();

let limit = if cfg!(debug_assertions) {
std::env::var("POLARS_ARRAY_EVAL_IDX_SIZE_LIMIT")
.map(|v| v.parse::<usize>().unwrap())
.unwrap_or(IdxSize::MAX as usize - 1)
} else {
(IdxSize::MAX - 1) as usize
};

if flattened_len > limit && width > 0 {
Comment thread
ritchie46 marked this conversation as resolved.
if state.verbose() {
eprintln!("IdxSize limit hit; chunking branch hit");
}

let rows_per_batch = limit / width;
polars_ensure!(rows_per_batch > 0, ComputeError: "array elements larger than IdxSize::MAX are not supported");
let mut batch_results: VecDeque<Column> = VecDeque::new();
let mut batch_row_start = 0usize;

while batch_row_start < ca.len() {
let batch_len = (ca.len() - batch_row_start).min(rows_per_batch);
let batch = ca.slice(batch_row_start as i64, batch_len);
batch_results
.push_back(self.evaluate_on_array_chunked(&batch, state, as_list, is_agg)?);
batch_row_start += batch_len;
}

let mut out = batch_results.pop_front().unwrap();
for other in batch_results {
out.append_owned(other)?;
}
return Ok(out);
}

let may_fail_on_masked_out_elements = self.evaluation_is_fallible && ca.has_nulls();

Expand Down
10 changes: 7 additions & 3 deletions py-polars/src/polars/_utils/construction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Final, get_type_hints

import polars as pl
from polars._dependencies import _check_for_pydantic, pydantic

if TYPE_CHECKING:
Expand Down Expand Up @@ -69,14 +70,17 @@ def is_sqlalchemy_row(value: Any) -> bool:
)


def get_first_non_none(values: Sequence[Any | None] | pl.Series) -> Any:
    """
    Return the first value from a sequence that isn't None.

    If sequence doesn't contain non-None values, return None.

    Parameters
    ----------
    values
        A plain sequence, or a polars Series (in which case all-null inputs
        are short-circuited without iterating the Series element-wise).
    """
    if isinstance(values, pl.Series):
        # Fast path: a Null-dtype or all-null Series has no non-None value by
        # construction. Skipping iteration avoids materialising Python objects
        # for every element, which can be very expensive for huge Series.
        if values.dtype == pl.Null or values.null_count() == len(values):
            return None

    return next((v for v in values if v is not None), None)


def nt_unpack(obj: Any) -> Any:
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/constructors/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def test_array_large_u64() -> None:
assert s.to_list() == values


def test_array_creation_idx_size() -> None:
    # Build an Array series whose flattened element count (4 * 2**31)
    # exceeds IdxSize::MAX; construction must not truncate offsets.
    width = 2**31
    row = pl.Series([None]).new_from_index(0, width)
    out = pl.Series("a", [row, row, row, row], dtype=pl.Array(pl.Null, width))
    assert out.shape == (4,)


def test_series_init_ambiguous_datetime() -> None:
value = datetime(2001, 10, 28, 2)
dtype = pl.Datetime(time_zone="Europe/Belgrade")
Expand Down
27 changes: 26 additions & 1 deletion py-polars/tests/unit/operations/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations

import datetime
from typing import Any
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from tests.conftest import PlMonkeyPatch


def test_arr_min_max() -> None:
s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2))
Expand Down Expand Up @@ -694,3 +697,25 @@ def test_array_get_broadcast_26217() -> None:
{"literal": [42, 13, 37, 13, 37, 42, 13]}, schema={"literal": pl.UInt8}
)
assert_frame_equal(out, expected)


@pytest.mark.may_fail_auto_streaming
@pytest.mark.debug
def test_array_idx_size_limit_eval(capfd: Any, plmonkeypatch: PlMonkeyPatch) -> None:
    # Lower the (debug-build) inner-element limit to 20 so the chunked
    # arr.eval code path triggers on a small series; verbose mode makes the
    # chunking branch log to stderr so we can assert it was actually taken.
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    plmonkeypatch.setenv("POLARS_ARRAY_EVAL_IDX_SIZE_LIMIT", "20")

    width = 19
    row = pl.Series([None]).new_from_index(0, width)
    df = pl.Series("a", [row, row, row, row], dtype=pl.Array(pl.Null, width)).to_frame()

    result = (
        df.select(pl.col("a").arr.eval(pl.element().len() * pl.element()))
        .head(1)
        .item()
        .to_list()
    )
    assert result == [None] * width

    captured = capfd.readouterr().err
    assert "IdxSize limit hit; chunking branch hit" in captured
Loading