Skip to content

Commit 2b825ad

Browse files
fix: Propagate null in min_by / max_by for all-null by groups (#26919)
1 parent 1a6e6b9 commit 2b825ad

2 files changed

Lines changed: 61 additions & 10 deletions

File tree

crates/polars-expr/src/expressions/aggregation.rs

Lines changed: 11 additions & 10 deletions
```diff
@@ -9,7 +9,6 @@ use polars_core::utils::{_split_offsets, NoNull};
 use polars_ops::prelude::ArgAgg;
 #[cfg(feature = "propagate_nans")]
 use polars_ops::prelude::nan_propagating_aggregate;
-use polars_utils::itertools::Itertools;
 use rayon::prelude::*;
 
 use super::*;
```
```diff
@@ -712,21 +711,23 @@ impl PhysicalExpr for AggMinMaxByExpr {
             unsafe { by_col.agg_arg_min(&by_groups) }
         };
         let idxs_in_groups: &IdxCa = idxs_in_groups.as_materialized_series().as_ref().as_ref();
-        let flat_gather_idxs = match input_groups.as_ref().as_ref() {
+        let gather_idxs: IdxCa = match input_groups.as_ref().as_ref() {
             GroupsType::Idx(g) => idxs_in_groups
-                .into_no_null_iter()
+                .iter()
                 .enumerate()
-                .map(|(group_idx, idx_in_group)| g.all()[group_idx][idx_in_group as usize])
-                .collect_vec(),
+                .map(|(group_idx, idx_in_group)| {
+                    idx_in_group.map(|i| g.all()[group_idx][i as usize])
+                })
+                .collect(),
             GroupsType::Slice { groups, .. } => idxs_in_groups
-                .into_no_null_iter()
+                .iter()
                 .enumerate()
-                .map(|(group_idx, idx_in_group)| groups[group_idx][0] + idx_in_group)
-                .collect_vec(),
+                .map(|(group_idx, idx_in_group)| idx_in_group.map(|i| groups[group_idx][0] + i))
+                .collect(),
         };
 
-        // SAFETY: All indices are within input_col's groups.
-        let gathered = unsafe { input_col.take_slice_unchecked(&flat_gather_idxs) };
+        // SAFETY: All non-null indices are within input_col's groups.
+        let gathered = unsafe { input_col.take_unchecked(&gather_idxs) };
         let agg_state = AggregatedScalar(gathered.with_name(keep_name));
         Ok(AggregationContext::from_agg_state(
             agg_state,
```

py-polars/tests/unit/operations/aggregation/test_aggregations.py

Lines changed: 50 additions & 0 deletions
```diff
@@ -1512,3 +1512,53 @@ def test_min_max_by_on_boolean_26847(
     df = pl.DataFrame({"a": [1] * 10, "b": [True] * 10})
     result = df.select(agg(pl.col("a"), pl.col("b")))
     assert result.item() == expected
+
+
+@pytest.mark.parametrize("agg", [pl.Expr.min_by, pl.Expr.max_by])
+def test_min_max_by_all_null_by_group(agg: Callable[..., pl.Expr]) -> None:
+    df = pl.DataFrame(
+        {
+            "g": ["a", "a", "b"],
+            "val": [1, 2, 3],
+            "by": pl.Series([None, None, 5], dtype=pl.Int64),
+        }
+    )
+    expected = pl.DataFrame(
+        {"g": ["a", "b"], "val": pl.Series([None, 3], dtype=pl.Int64)}
+    )
+
+    eager = df.group_by("g", maintain_order=True).agg(agg(pl.col("val"), pl.col("by")))
+    assert_frame_equal(eager, expected)
+
+    streaming = (
+        df.lazy()
+        .group_by("g", maintain_order=True)
+        .agg(agg(pl.col("val"), pl.col("by")))
+        .collect(engine="streaming")
+    )
+    assert_frame_equal(streaming, expected)
+
+
+@pytest.mark.parametrize("agg", [pl.Expr.min_by, pl.Expr.max_by])
+def test_min_max_by_all_null_by_group_slice(agg: Callable[..., pl.Expr]) -> None:
+    df = pl.DataFrame(
+        {
+            "dt": [date(2020, 1, 1), date(2020, 1, 1), date(2020, 2, 1)],
+            "val": [1, 2, 3],
+            "by": pl.Series([None, None, 5], dtype=pl.Int64),
+        }
+    )
+    expected = pl.DataFrame(
+        {
+            "dt": [date(2020, 1, 1), date(2020, 2, 1)],
+            "val": pl.Series([None, 3], dtype=pl.Int64),
+        }
+    )
+
+    result = (
+        df.lazy()
+        .group_by_dynamic("dt", every="1mo")
+        .agg(agg(pl.col("val"), pl.col("by")))
+        .collect()
+    )
+    assert_frame_equal(result, expected)
```

0 commit comments

Comments (0)