Skip to content

Commit 5993d95

Browse files
authored
feat: Add per partition sort and finish callback to sinks (#22789)
1 parent e25db0b commit 5993d95

23 files changed

Lines changed: 1007 additions & 110 deletions

File tree

crates/polars-expr/src/reduce/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use std::marker::PhantomData;
1515
use arrow::array::{Array, PrimitiveArray, StaticArray};
1616
use arrow::bitmap::{Bitmap, BitmapBuilder, MutableBitmap};
1717
pub use convert::into_reduction;
18+
pub use min_max::{new_max_reduction, new_min_reduction};
1819
use polars_core::prelude::*;
1920

2021
use crate::EvictIdx;

crates/polars-lazy/src/frame/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,7 @@ impl LazyFrame {
10811081
/// final result doesn't fit into memory. This method will return an error if the query cannot
10821082
/// be completely done in a streaming fashion.
10831083
#[cfg(feature = "parquet")]
1084+
#[allow(clippy::too_many_arguments)]
10841085
pub fn sink_parquet_partitioned(
10851086
self,
10861087
base_path: Arc<PathBuf>,
@@ -1089,6 +1090,8 @@ impl LazyFrame {
10891090
options: ParquetWriteOptions,
10901091
cloud_options: Option<polars_io::cloud::CloudOptions>,
10911092
sink_options: SinkOptions,
1093+
per_partition_sort_by: Option<Vec<SortColumn>>,
1094+
finish_callback: Option<SinkFinishCallback>,
10921095
) -> PolarsResult<Self> {
10931096
self.sink(SinkType::Partition(PartitionSinkType {
10941097
base_path,
@@ -1097,13 +1100,16 @@ impl LazyFrame {
10971100
variant,
10981101
file_type: FileType::Parquet(options),
10991102
cloud_options,
1103+
per_partition_sort_by,
1104+
finish_callback,
11001105
}))
11011106
}
11021107

11031108
/// Stream a query result into an ipc/arrow file in a partitioned manner. This is useful if the
11041109
/// final result doesn't fit into memory. This method will return an error if the query cannot
11051110
/// be completely done in a streaming fashion.
11061111
#[cfg(feature = "ipc")]
1112+
#[allow(clippy::too_many_arguments)]
11071113
pub fn sink_ipc_partitioned(
11081114
self,
11091115
base_path: Arc<PathBuf>,
@@ -1112,6 +1118,8 @@ impl LazyFrame {
11121118
options: IpcWriterOptions,
11131119
cloud_options: Option<polars_io::cloud::CloudOptions>,
11141120
sink_options: SinkOptions,
1121+
per_partition_sort_by: Option<Vec<SortColumn>>,
1122+
finish_callback: Option<SinkFinishCallback>,
11151123
) -> PolarsResult<Self> {
11161124
self.sink(SinkType::Partition(PartitionSinkType {
11171125
base_path,
@@ -1120,13 +1128,16 @@ impl LazyFrame {
11201128
variant,
11211129
file_type: FileType::Ipc(options),
11221130
cloud_options,
1131+
per_partition_sort_by,
1132+
finish_callback,
11231133
}))
11241134
}
11251135

11261136
/// Stream a query result into a csv file in a partitioned manner. This is useful if the final
11271137
/// result doesn't fit into memory. This method will return an error if the query cannot be
11281138
/// completely done in a streaming fashion.
11291139
#[cfg(feature = "csv")]
1140+
#[allow(clippy::too_many_arguments)]
11301141
pub fn sink_csv_partitioned(
11311142
self,
11321143
base_path: Arc<PathBuf>,
@@ -1135,6 +1146,8 @@ impl LazyFrame {
11351146
options: CsvWriterOptions,
11361147
cloud_options: Option<polars_io::cloud::CloudOptions>,
11371148
sink_options: SinkOptions,
1149+
per_partition_sort_by: Option<Vec<SortColumn>>,
1150+
finish_callback: Option<SinkFinishCallback>,
11381151
) -> PolarsResult<Self> {
11391152
self.sink(SinkType::Partition(PartitionSinkType {
11401153
base_path,
@@ -1143,13 +1156,16 @@ impl LazyFrame {
11431156
variant,
11441157
file_type: FileType::Csv(options),
11451158
cloud_options,
1159+
per_partition_sort_by,
1160+
finish_callback,
11461161
}))
11471162
}
11481163

11491164
/// Stream a query result into a JSON file in a partitioned manner. This is useful if the final
11501165
/// result doesn't fit into memory. This method will return an error if the query cannot be
11511166
/// completely done in a streaming fashion.
11521167
#[cfg(feature = "json")]
1168+
#[allow(clippy::too_many_arguments)]
11531169
pub fn sink_json_partitioned(
11541170
self,
11551171
base_path: Arc<PathBuf>,
@@ -1158,6 +1174,8 @@ impl LazyFrame {
11581174
options: JsonWriterOptions,
11591175
cloud_options: Option<polars_io::cloud::CloudOptions>,
11601176
sink_options: SinkOptions,
1177+
per_partition_sort_by: Option<Vec<SortColumn>>,
1178+
finish_callback: Option<SinkFinishCallback>,
11611179
) -> PolarsResult<Self> {
11621180
self.sink(SinkType::Partition(PartitionSinkType {
11631181
base_path,
@@ -1166,6 +1184,8 @@ impl LazyFrame {
11661184
variant,
11671185
file_type: FileType::Json(options),
11681186
cloud_options,
1187+
per_partition_sort_by,
1188+
finish_callback,
11691189
}))
11701190
}
11711191

crates/polars-plan/src/dsl/options/sink.rs

Lines changed: 152 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::path::PathBuf;
44
use std::sync::Arc;
55

66
use polars_core::error::PolarsResult;
7+
use polars_core::frame::DataFrame;
78
use polars_core::prelude::DataType;
89
use polars_core::scalar::Scalar;
910
use polars_io::cloud::CloudOptions;
@@ -99,6 +100,13 @@ impl SinkTarget {
99100
)),
100101
}
101102
}
103+
104+
pub fn to_display_string(&self) -> String {
105+
match self {
106+
Self::Path(p) => p.display().to_string(),
107+
Self::Dyn(_) => "dynamic-target".to_string(),
108+
}
109+
}
102110
}
103111

104112
impl fmt::Debug for SinkTarget {
@@ -260,6 +268,47 @@ pub enum PartitionTargetCallback {
260268
Python(polars_utils::python_function::PythonFunction),
261269
}
262270

271+
#[cfg_attr(feature = "python", pyo3::pyclass)]
272+
pub struct SinkWritten {
273+
pub file_idx: usize,
274+
pub part_idx: usize,
275+
pub in_part_idx: usize,
276+
pub keys: Vec<PartitionTargetContextKey>,
277+
pub file_path: PathBuf,
278+
pub full_path: PathBuf,
279+
pub num_rows: usize,
280+
pub file_size: usize,
281+
pub gathered: Option<DataFrame>,
282+
}
283+
284+
#[cfg_attr(feature = "python", pyo3::pyclass)]
285+
pub struct SinkFinishContext {
286+
pub written: Vec<SinkWritten>,
287+
}
288+
289+
#[derive(Clone, Debug, PartialEq)]
290+
pub enum SinkFinishCallback {
291+
Rust(SpecialEq<Arc<dyn Fn(DataFrame) -> PolarsResult<()> + Send + Sync>>),
292+
#[cfg(feature = "python")]
293+
Python(polars_utils::python_function::PythonFunction),
294+
}
295+
296+
impl SinkFinishCallback {
297+
pub fn call(&self, df: DataFrame) -> PolarsResult<()> {
298+
match self {
299+
Self::Rust(f) => f(df),
300+
#[cfg(feature = "python")]
301+
Self::Python(f) => pyo3::Python::with_gil(|py| {
302+
let converter =
303+
polars_utils::python_convert_registry::get_python_convert_registry();
304+
let df = (converter.to_py.df)(Box::new(df) as Box<dyn std::any::Any>)?;
305+
f.call1(py, (df,))?;
306+
PolarsResult::Ok(())
307+
}),
308+
}
309+
}
310+
}
311+
263312
impl PartitionTargetCallback {
264313
pub fn call(&self, ctx: PartitionTargetContext) -> PolarsResult<SinkTarget> {
265314
match self {
@@ -277,6 +326,60 @@ impl PartitionTargetCallback {
277326
}
278327
}
279328

329+
#[cfg(feature = "serde")]
impl serde::Serialize for SinkFinishCallback {
    /// Serializes the callback.
    ///
    /// Only the `Python` variant (feature `python`) is serializable; it
    /// delegates to the wrapped `PythonFunction`. A `Rust` closure always
    /// produces a custom serializer error.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::Error;

        match self {
            #[cfg(feature = "python")]
            Self::Python(py_fn) => py_fn.serialize(_serializer),
            Self::Rust(_) => Err(S::Error::custom(format!("cannot serialize {self:?}"))),
        }
    }
}
345+
346+
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SinkFinishCallback {
    /// Deserializes a callback.
    ///
    /// With the `python` feature the input is always read as a
    /// `PythonFunction` and wrapped in the `Python` variant. Without it,
    /// deserialization always fails, since a Rust closure cannot be
    /// reconstructed from serialized data.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(feature = "python")]
        {
            polars_utils::python_function::PythonFunction::deserialize(_deserializer)
                .map(Self::Python)
        }
        #[cfg(not(feature = "python"))]
        {
            use serde::de::Error;
            // Name the type actually being deserialized so the error is traceable.
            Err(D::Error::custom("cannot deserialize SinkFinishCallback"))
        }
    }
}
367+
368+
#[cfg(feature = "dsl-schema")]
impl schemars::JsonSchema for SinkFinishCallback {
    /// Short schema name; matches the type name and the tail of `schema_id`
    /// so it does not clash with `PartitionTargetCallback`'s schema.
    fn schema_name() -> String {
        "SinkFinishCallback".to_owned()
    }

    /// Globally unique schema identifier: the module path plus the type name.
    fn schema_id() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed(concat!(module_path!(), "::", "SinkFinishCallback"))
    }

    /// Describes the callback with the same schema as `Vec<u8>`.
    ///
    /// NOTE(review): presumably because the serialized form is the Python
    /// function's byte representation — confirm against `PythonFunction`.
    fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
        Vec::<u8>::json_schema(generator)
    }
}
382+
280383
#[cfg(feature = "serde")]
281384
impl<'de> serde::Deserialize<'de> for PartitionTargetCallback {
282385
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
@@ -331,6 +434,23 @@ impl schemars::JsonSchema for PartitionTargetCallback {
331434
}
332435
}
333436

437+
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
438+
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
439+
#[derive(Clone, Debug, PartialEq)]
440+
pub struct SortColumn {
441+
pub expr: Expr,
442+
pub descending: bool,
443+
pub nulls_last: bool,
444+
}
445+
446+
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
447+
#[derive(Clone, Debug, PartialEq)]
448+
pub struct SortColumnIR {
449+
pub expr: ExprIR,
450+
pub descending: bool,
451+
pub nulls_last: bool,
452+
}
453+
334454
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
335455
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
336456
#[derive(Clone, Debug, PartialEq)]
@@ -341,6 +461,8 @@ pub struct PartitionSinkType {
341461
pub sink_options: SinkOptions,
342462
pub variant: PartitionVariant,
343463
pub cloud_options: Option<polars_io::cloud::CloudOptions>,
464+
pub per_partition_sort_by: Option<Vec<SortColumn>>,
465+
pub finish_callback: Option<SinkFinishCallback>,
344466
}
345467

346468
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@@ -352,6 +474,8 @@ pub struct PartitionSinkTypeIR {
352474
pub sink_options: SinkOptions,
353475
pub variant: PartitionVariantIR,
354476
pub cloud_options: Option<polars_io::cloud::CloudOptions>,
477+
pub per_partition_sort_by: Option<Vec<SortColumnIR>>,
478+
pub finish_callback: Option<SinkFinishCallback>,
355479
}
356480

357481
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@@ -392,23 +516,44 @@ pub enum PartitionVariantIR {
392516
},
393517
}
394518

519+
#[cfg(feature = "cse")]
impl SinkTypeIR {
    /// Feeds this sink type into `state`.
    ///
    /// The enum discriminant is hashed first, followed by the variant's
    /// payload: `Memory` contributes nothing more, `File` uses its `Hash`
    /// impl, and `Partition` delegates to its own `traverse_and_hash`, which
    /// needs `expr_arena` to hash the expressions it refers to.
    pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
        std::mem::discriminant(self).hash(state);
        match self {
            Self::Memory => {},
            Self::File(file_sink) => file_sink.hash(state),
            Self::Partition(partition_sink) => {
                partition_sink.traverse_and_hash(expr_arena, state)
            },
        }
    }
}
530+
531+
#[cfg(feature = "cse")]
impl PartitionSinkTypeIR {
    /// Feeds this partition sink into `state`.
    ///
    /// Hashed, in order: `file_type`, `sink_options`, `variant` (through the
    /// arena), `cloud_options`, and then `per_partition_sort_by` as its
    /// `Option` discriminant followed, when present, by the number of sort
    /// columns and each column in turn. `finish_callback` is not part of the
    /// hash.
    pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
        self.file_type.hash(state);
        self.sink_options.hash(state);
        self.variant.traverse_and_hash(expr_arena, state);
        self.cloud_options.hash(state);

        let sort_by = &self.per_partition_sort_by;
        std::mem::discriminant(sort_by).hash(state);
        if let Some(columns) = sort_by {
            columns.len().hash(state);
            for column in columns {
                column.traverse_and_hash(expr_arena, state);
            }
        }
    }
}
411547

548+
#[cfg(feature = "cse")]
impl SortColumnIR {
    /// Feeds this sort column into `state`: the expression first (resolved
    /// through `expr_arena`), then the `descending` and `nulls_last` flags.
    pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
        let Self {
            expr,
            descending,
            nulls_last,
        } = self;
        expr.traverse_and_hash(expr_arena, state);
        descending.hash(state);
        nulls_last.hash(state);
    }
}
556+
412557
impl PartitionVariantIR {
413558
#[cfg(feature = "cse")]
414559
pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {

crates/polars-plan/src/plans/conversion/dsl_to_ir.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,6 +1186,25 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
11861186
},
11871187
},
11881188
cloud_options: f.cloud_options,
1189+
// Lower each DSL sort column to its IR form. The first expression that
// fails to convert aborts the whole conversion via `?`.
per_partition_sort_by: f
    .per_partition_sort_by
    .map(|sort_by| {
        sort_by
            .into_iter()
            .map(|column| {
                let expr = to_expr_ir(column.expr, ctxt.expr_arena)?;
                // Register the new expression node with the conversion optimizer.
                ctxt.conversion_optimizer
                    .push_scratch(expr.node(), ctxt.expr_arena);
                Ok(SortColumnIR {
                    expr,
                    descending: column.descending,
                    nulls_last: column.nulls_last,
                })
            })
            .collect::<PolarsResult<Vec<_>>>()
    })
    .transpose()?,
1207+
finish_callback: f.finish_callback,
11891208
}),
11901209
};
11911210

crates/polars-plan/src/plans/conversion/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,17 @@ impl IR {
302302
},
303303
},
304304
cloud_options: f.cloud_options,
305+
// Convert each IR sort column back to its DSL form, preserving both
// ordering flags.
per_partition_sort_by: f.per_partition_sort_by.map(|sort_by| {
    sort_by
        .into_iter()
        .map(|s| SortColumn {
            expr: s.expr.to_expr(expr_arena),
            descending: s.descending,
            // Take the column's own null ordering; it is independent of
            // `descending` and must not be derived from it.
            nulls_last: s.nulls_last,
        })
        .collect()
}),
315+
finish_callback: f.finish_callback,
305316
}),
306317
};
307318
DslPlan::Sink { input, payload }

0 commit comments

Comments
 (0)