342 changes: 198 additions & 144 deletions native/Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions native/Cargo.toml
@@ -31,15 +31,15 @@ license = "Apache-2.0"
edition = "2021"

# Comet uses the same minimum Rust version as DataFusion
rust-version = "1.85"
rust-version = "1.86"

[workspace.dependencies]
arrow = { version = "55.2.0", features = ["prettyprint", "ffi", "chrono-tz"] }
arrow = { version = "56.0.0", features = ["prettyprint", "ffi", "chrono-tz"] }
async-trait = { version = "0.1" }
bytes = { version = "1.10.0" }
parquet = { version = "55.2.0", default-features = false, features = ["experimental"] }
datafusion = { version = "49.0.2", default-features = false, features = ["unicode_expressions", "crypto_expressions", "nested_expressions", "parquet"] }
datafusion-spark = { version = "49.0.2" }
parquet = { version = "=56.0.0", default-features = false, features = ["experimental"] }
datafusion = { version = "50.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions", "nested_expressions", "parquet"] }
datafusion-spark = { version = "50.0.0" }
datafusion-comet-spark-expr = { path = "spark-expr" }
datafusion-comet-proto = { path = "proto" }
chrono = { version = "0.4", default-features = false, features = ["clock"] }
4 changes: 2 additions & 2 deletions native/core/Cargo.toml
@@ -36,7 +36,7 @@ publish = false

[dependencies]
arrow = { workspace = true }
parquet = { workspace = true, default-features = false, features = ["experimental"] }
parquet = { workspace = true, default-features = false, features = ["experimental", "arrow"] }
futures = { workspace = true }
mimalloc = { version = "*", default-features = false, optional = true }
tikv-jemallocator = { version = "0.6.0", optional = true, features = ["disable_initial_exec_tls"] }
@@ -91,7 +91,7 @@ jni = { version = "0.21", features = ["invocation"] }
lazy_static = "1.4"
assertables = "9"
hex = "0.4.3"
datafusion-functions-nested = { version = "49.0.2" }
datafusion-functions-nested = { version = "50.0.0" }

[features]
default = []
9 changes: 4 additions & 5 deletions native/core/src/execution/jni_api.rs
@@ -35,14 +35,14 @@ use datafusion::execution::memory_pool::MemoryPool;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::logical_expr::ScalarUDF;
use datafusion::{
execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnv},
execution::disk_manager::DiskManagerBuilder,
physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream},
prelude::{SessionConfig, SessionContext},
};
use datafusion_comet_proto::spark_operator::Operator;
use datafusion_spark::function::hash::sha2::SparkSha2;
use datafusion_spark::function::math::expm1::SparkExpm1;
use datafusion_spark::function::string::char::SparkChar;
use datafusion_spark::function::string::char::CharFunc;
use futures::poll;
use futures::stream::StreamExt;
use jni::objects::JByteBuffer;
@@ -291,8 +291,7 @@ fn prepare_datafusion_session_context(
&ScalarValue::Float64(Some(1.1)),
);

#[allow(deprecated)]
let runtime = RuntimeEnv::try_new(rt_config)?;
let runtime = rt_config.build()?;

let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));

@@ -301,7 +300,7 @@
// register UDFs from datafusion-spark crate
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkChar::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));

// Must be the last one to override existing functions with the same name
datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;
42 changes: 34 additions & 8 deletions native/core/src/execution/planner.rs
@@ -39,6 +39,7 @@ use datafusion::physical_plan::InputOrderMode;
use datafusion::{
arrow::{compute::SortOptions, datatypes::SchemaRef},
common::DataFusionError,
config::ConfigOptions,
execution::FunctionRegistry,
functions_aggregate::first_last::{FirstValue, LastValue},
logical_expr::Operator as DataFusionOperator,
@@ -622,8 +623,13 @@ impl PhysicalPlanner {
let args = vec![child];
let comet_hour = Arc::new(ScalarUDF::new_from_impl(SparkHour::new(timezone)));
let field_ref = Arc::new(Field::new("hour", DataType::Int32, true));
let expr: ScalarFunctionExpr =
ScalarFunctionExpr::new("hour", comet_hour, args, field_ref);
let expr: ScalarFunctionExpr = ScalarFunctionExpr::new(
"hour",
comet_hour,
args,
field_ref,
Arc::new(ConfigOptions::default()),
Review comment (Contributor Author): Instead of possibly instantiating multiple default ConfigOptions, in the future we could stash one somewhere. This would have the benefits of:

  1. A custom config would propagate throughout
  2. Reduced memory overhead

Reply (Contributor): That would be in the ExecutionContext, perhaps?

);

Ok(Arc::new(expr))
}
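A minimal sketch of the idea in the review discussion above, assuming a hypothetical PlannerState holder (not an existing Comet type): keep one shared ConfigOptions and hand out Arc clones at each ScalarFunctionExpr::new call site instead of building a fresh default per expression.

use std::sync::Arc;
use datafusion::config::ConfigOptions;

// Hypothetical holder: a single ConfigOptions shared by every expression the
// planner builds, so a custom setting propagates everywhere and the options
// are not duplicated per call site.
struct PlannerState {
    config_options: Arc<ConfigOptions>,
}

impl PlannerState {
    fn new() -> Self {
        Self {
            config_options: Arc::new(ConfigOptions::default()),
        }
    }

    // Cloning the Arc is a cheap pointer copy; every call site observes the
    // same underlying ConfigOptions instance.
    fn config_options(&self) -> Arc<ConfigOptions> {
        Arc::clone(&self.config_options)
    }
}

A call site would then pass planner_state.config_options() instead of Arc::new(ConfigOptions::default()).
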
@@ -634,8 +640,13 @@ impl PhysicalPlanner {
let args = vec![child];
let comet_minute = Arc::new(ScalarUDF::new_from_impl(SparkMinute::new(timezone)));
let field_ref = Arc::new(Field::new("minute", DataType::Int32, true));
let expr: ScalarFunctionExpr =
ScalarFunctionExpr::new("minute", comet_minute, args, field_ref);
let expr: ScalarFunctionExpr = ScalarFunctionExpr::new(
"minute",
comet_minute,
args,
field_ref,
Arc::new(ConfigOptions::default()),
);

Ok(Arc::new(expr))
}
@@ -646,8 +657,13 @@ impl PhysicalPlanner {
let args = vec![child];
let comet_second = Arc::new(ScalarUDF::new_from_impl(SparkSecond::new(timezone)));
let field_ref = Arc::new(Field::new("second", DataType::Int32, true));
let expr: ScalarFunctionExpr =
ScalarFunctionExpr::new("second", comet_second, args, field_ref);
let expr: ScalarFunctionExpr = ScalarFunctionExpr::new(
"second",
comet_second,
args,
field_ref,
Arc::new(ConfigOptions::default()),
);

Ok(Arc::new(expr))
}
@@ -869,8 +885,13 @@ impl PhysicalPlanner {
ScalarUDF::new_from_impl(BloomFilterMightContain::try_new(bloom_filter_expr)?);

let field_ref = Arc::new(Field::new("might_contain", DataType::Boolean, true));
let expr: ScalarFunctionExpr =
ScalarFunctionExpr::new("might_contain", Arc::new(udf), args, field_ref);
let expr: ScalarFunctionExpr = ScalarFunctionExpr::new(
"might_contain",
Arc::new(udf),
args,
field_ref,
Arc::new(ConfigOptions::default()),
);
Ok(Arc::new(expr))
}
ExprStruct::CreateNamedStruct(expr) => {
@@ -1089,6 +1110,7 @@ impl PhysicalPlanner {
fun_expr,
vec![left, right],
Arc::new(Field::new(func_name, data_type, true)),
Arc::new(ConfigOptions::default()),
)))
}
_ => {
@@ -1114,6 +1136,7 @@ impl PhysicalPlanner {
fun_expr,
vec![left, right],
Arc::new(Field::new(op_str, data_type, true)),
Arc::new(ConfigOptions::default()),
)))
} else {
Ok(Arc::new(BinaryExpr::new(left, op, right)))
@@ -2375,6 +2398,8 @@ impl PhysicalPlanner {
window_frame.into(),
input_schema.as_ref(),
false, // TODO: Ignore nulls
false, // TODO: Spark does not support DISTINCT ... OVER
None,
)
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
}
@@ -2554,6 +2579,7 @@ impl PhysicalPlanner {
fun_expr,
args.to_vec(),
Arc::new(Field::new(fun_name, data_type, true)),
Arc::new(ConfigOptions::default()),
));

Ok(scalar_expr)
2 changes: 1 addition & 1 deletion native/core/src/execution/shuffle/shuffle_writer.rs
@@ -1350,7 +1350,7 @@ mod test {
#[tokio::test]
async fn shuffle_repartitioner_memory() {
let batch = create_batch(900);
assert_eq!(8376, batch.get_array_memory_size());
assert_eq!(8316, batch.get_array_memory_size()); // Not stable across Arrow versions

let memory_limit = 512 * 1024;
let num_partitions = 2;
2 changes: 1 addition & 1 deletion native/spark-expr/src/agg_funcs/avg.rs
@@ -37,7 +37,7 @@ use datafusion::logical_expr::Volatility::Immutable;
use DataType::*;

/// AVG aggregate expression
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Avg {
name: String,
signature: Signature,
2 changes: 1 addition & 1 deletion native/spark-expr/src/agg_funcs/avg_decimal.rs
@@ -40,7 +40,7 @@ use num::{integer::div_ceil, Integer};
use DataType::*;

/// AVG aggregate expression
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AvgDecimal {
signature: Signature,
sum_data_type: DataType,
2 changes: 1 addition & 1 deletion native/spark-expr/src/agg_funcs/correlation.rs
@@ -38,7 +38,7 @@ use datafusion::physical_expr::expressions::StatsType;
/// we have our own implementation is that DataFusion has UInt64 for state_field `count`,
/// while Spark has Double for count. Also we have added `null_on_divide_by_zero`
/// to be consistent with Spark's implementation.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Correlation {
name: String,
signature: Signature,
11 changes: 10 additions & 1 deletion native/spark-expr/src/agg_funcs/covariance.rs
@@ -38,14 +38,23 @@ use std::sync::Arc;
/// The implementation mostly is the same as the DataFusion's implementation. The reason
/// we have our own implementation is that DataFusion has UInt64 for state_field count,
/// while Spark has Double for count.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Covariance {
name: String,
signature: Signature,
stats_type: StatsType,
null_on_divide_by_zero: bool,
}

impl std::hash::Hash for Covariance {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
(self.stats_type as u8).hash(state);
Review comment (@mbutrovich, Sep 16, 2025): StatsType does not implement Hash, so here we do a cast. std::mem::discriminant is the solution if the enum gets any more complex.

self.null_on_divide_by_zero.hash(state);
}
}
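
A minimal sketch of the alternative named in the review comment above: hash the variant through std::mem::discriminant rather than an as-u8 cast, which keeps compiling even if the enum later gains data-carrying variants. MyStatsType is a stand-in enum for illustration, not DataFusion's StatsType.

use std::hash::{Hash, Hasher};

// Stand-in for an enum (like StatsType) that does not implement Hash.
#[derive(Clone, Copy)]
enum MyStatsType {
    Population,
    Sample,
}

struct Stats {
    stats_type: MyStatsType,
    null_on_divide_by_zero: bool,
}

impl Hash for Stats {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // The discriminant identifies the variant without requiring a cast
        // to an integer, unlike `self.stats_type as u8`.
        std::mem::discriminant(&self.stats_type).hash(state);
        self.null_on_divide_by_zero.hash(state);
    }
}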

impl Covariance {
/// Create a new COVAR aggregate function
pub fn new(
11 changes: 10 additions & 1 deletion native/spark-expr/src/agg_funcs/stddev.rs
@@ -36,14 +36,23 @@ use datafusion::physical_expr::expressions::StatsType;
/// we have our own implementation is that DataFusion has UInt64 for state_field `count`,
/// while Spark has Double for count. Also we have added `null_on_divide_by_zero`
/// to be consistent with Spark's implementation.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub struct Stddev {
name: String,
signature: Signature,
stats_type: StatsType,
null_on_divide_by_zero: bool,
}

impl std::hash::Hash for Stddev {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
(self.stats_type as u8).hash(state);
self.null_on_divide_by_zero.hash(state);
}
}

impl Stddev {
/// Create a new STDDEV aggregate function
pub fn new(
2 changes: 1 addition & 1 deletion native/spark-expr/src/agg_funcs/sum_decimal.rs
@@ -29,7 +29,7 @@ use datafusion::logical_expr::{
};
use std::{any::Any, ops::BitAnd, sync::Arc};

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SumDecimal {
/// Aggregate function signature
signature: Signature,
11 changes: 10 additions & 1 deletion native/spark-expr/src/agg_funcs/variance.rs
@@ -34,14 +34,23 @@ use std::sync::Arc;
/// we have our own implementation is that DataFusion has UInt64 for state_field `count`,
/// while Spark has Double for count. Also we have added `null_on_divide_by_zero`
/// to be consistent with Spark's implementation.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub struct Variance {
name: String,
signature: Signature,
stats_type: StatsType,
null_on_divide_by_zero: bool,
}

impl std::hash::Hash for Variance {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
(self.stats_type as u8).hash(state);
self.null_on_divide_by_zero.hash(state);
}
}

impl Variance {
/// Create a new VARIANCE aggregate function
pub fn new(
2 changes: 1 addition & 1 deletion native/spark-expr/src/bitwise_funcs/bitwise_count.rs
@@ -22,7 +22,7 @@ use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkBitwiseCount {
signature: Signature,
aliases: Vec<String>,
2 changes: 1 addition & 1 deletion native/spark-expr/src/bitwise_funcs/bitwise_get.rs
@@ -22,7 +22,7 @@ use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Vol
use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkBitwiseGet {
signature: Signature,
aliases: Vec<String>,
2 changes: 1 addition & 1 deletion native/spark-expr/src/bitwise_funcs/bitwise_not.rs
@@ -23,7 +23,7 @@ use datafusion::logical_expr::{ColumnarValue, Volatility};
use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use std::{any::Any, sync::Arc};

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkBitwiseNot {
signature: Signature,
aliases: Vec<String>,
2 changes: 1 addition & 1 deletion native/spark-expr/src/bloom_filter/bloom_filter_agg.rs
@@ -32,7 +32,7 @@ use datafusion::physical_expr::expressions::Literal;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::Accumulator;

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BloomFilterAgg {
signature: Signature,
num_items: i32,
@@ -27,7 +27,7 @@ use std::sync::Arc;

use crate::bloom_filter::spark_bloom_filter::SparkBloomFilter;

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct BloomFilterMightContain {
signature: Signature,
bloom_filter: Option<SparkBloomFilter>,
20 changes: 20 additions & 0 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -192,6 +192,26 @@ struct CometScalarFunction {
func: ScalarFunctionImplementation,
}

impl PartialEq for CometScalarFunction {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
&& self.signature == other.signature
&& self.data_type == other.data_type
// Note: we do not test ScalarFunctionImplementation equality, relying on function metadata.
}
}

impl Eq for CometScalarFunction {}

impl std::hash::Hash for CometScalarFunction {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
self.data_type.hash(state);
// Note: we do not hash ScalarFunctionImplementation, relying on function metadata.
}
}

impl Debug for CometScalarFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CometScalarFunction")
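A small self-contained sketch of the pattern used in CometScalarFunction above, assuming only the standard library: a struct holding a boxed closure cannot derive PartialEq or Hash, so both are implemented over the comparable metadata fields and the closure is deliberately skipped.

use std::hash::{Hash, Hasher};

// The closure field cannot be compared or hashed, so equality and hashing
// are defined over the metadata fields only.
struct NamedFunction {
    name: String,
    arity: usize,
    func: Box<dyn Fn(i64) -> i64>,
}

impl PartialEq for NamedFunction {
    fn eq(&self, other: &Self) -> bool {
        // `func` is intentionally ignored, mirroring CometScalarFunction.
        self.name == other.name && self.arity == other.arity
    }
}

impl Eq for NamedFunction {}

impl Hash for NamedFunction {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
        self.arity.hash(state);
        // `func` is intentionally not hashed.
    }
}
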
2 changes: 1 addition & 1 deletion native/spark-expr/src/datetime_funcs/date_trunc.rs
@@ -24,7 +24,7 @@ use std::any::Any;

use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkDateTrunc {
signature: Signature,
aliases: Vec<String>,
2 changes: 1 addition & 1 deletion native/spark-expr/src/datetime_funcs/extract_date_part.rs
@@ -26,7 +26,7 @@ use std::{any::Any, fmt::Debug};

macro_rules! extract_date_part {
($struct_name:ident, $fn_name:expr, $date_part_variant:ident) => {
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct $struct_name {
signature: Signature,
aliases: Vec<String>,