diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 13f6b135fb..af93e5abd2 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -126,3 +126,7 @@ harness = false
 [[bench]]
 name = "aggregate"
 harness = false
+
+[[bench]]
+name = "bloom_filter_agg"
+harness = false
diff --git a/native/core/benches/bloom_filter_agg.rs b/native/core/benches/bloom_filter_agg.rs
new file mode 100644
index 0000000000..90e3e3f645
--- /dev/null
+++ b/native/core/benches/bloom_filter_agg.rs
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow_array::builder::Int64Builder;
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::SchemaRef;
+use comet::execution::datafusion::expressions::bloom_filter_agg::BloomFilterAgg;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion_common::ScalarValue;
+use datafusion_execution::TaskContext;
+use datafusion_expr::AggregateUDF;
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+use datafusion_physical_expr::expressions::{Column, Literal};
+use futures::StreamExt;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::runtime::Runtime;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let mut group = c.benchmark_group("bloom_filter_agg");
+    let num_rows = 8192;
+    let batch = create_record_batch(num_rows);
+    let mut batches = Vec::new();
+    for _ in 0..10 {
+        batches.push(batch.clone());
+    }
+    let partitions = &[batches];
+    let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
+    // spark.sql.optimizer.runtime.bloomFilter.expectedNumItems
+    let num_items_sv = ScalarValue::Int64(Some(1000000_i64));
+    let num_items: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_items_sv));
+    // spark.sql.optimizer.runtime.bloomFilter.numBits
+    let num_bits_sv = ScalarValue::Int64(Some(8388608_i64));
+    let num_bits: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_bits_sv));
+
+    let rt = Runtime::new().unwrap();
+
+    for agg_mode in [
+        ("partial_agg", AggregateMode::Partial),
+        ("single_agg", AggregateMode::Single),
+    ] {
+        group.bench_function(agg_mode.0, |b| {
+            let comet_bloom_filter_agg =
+                Arc::new(AggregateUDF::new_from_impl(BloomFilterAgg::new(
+                    Arc::clone(&c0),
+                    Arc::clone(&num_items),
+                    Arc::clone(&num_bits),
+                    "bloom_filter_agg",
+                    DataType::Binary,
+                )));
+            b.to_async(&rt).iter(|| {
+                black_box(agg_test(
+                    partitions,
+                    c0.clone(),
+                    comet_bloom_filter_agg.clone(),
+                    "bloom_filter_agg",
+                    agg_mode.1,
+                ))
+            })
+        });
+    }
+
+    group.finish();
+}
+
+async fn agg_test(
+    partitions: &[Vec<RecordBatch>],
+    c0: Arc<dyn PhysicalExpr>,
+    aggregate_udf: Arc<AggregateUDF>,
+    alias: &str,
+    mode: AggregateMode,
+) {
+    let schema = &partitions[0][0].schema();
+    let scan: Arc<dyn ExecutionPlan> =
+        Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), None).unwrap());
+    let aggregate = create_aggregate(scan, c0.clone(), schema, aggregate_udf, alias, mode);
+    let mut stream = aggregate
+        .execute(0, Arc::new(TaskContext::default()))
+        .unwrap();
+    while let Some(batch) = stream.next().await {
+        let _batch = batch.unwrap();
+    }
+}
+
+fn create_aggregate(
+    scan: Arc<dyn ExecutionPlan>,
+    c0: Arc<dyn PhysicalExpr>,
+    schema: &SchemaRef,
+    aggregate_udf: Arc<AggregateUDF>,
+    alias: &str,
+    mode: AggregateMode,
+) -> Arc<dyn ExecutionPlan> {
+    let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c0.clone()])
+        .schema(schema.clone())
+        .alias(alias)
+        .with_ignore_nulls(false)
+        .with_distinct(false)
+        .build()
+        .unwrap();
+
+    Arc::new(
+        AggregateExec::try_new(
+            mode,
+            PhysicalGroupBy::new_single(vec![]),
+            vec![aggr_expr],
+            vec![None],
+            scan,
+            Arc::clone(schema),
+        )
+        .unwrap(),
+    )
+}
+
+fn create_record_batch(num_rows: usize) -> RecordBatch {
+    let mut int64_builder = Int64Builder::with_capacity(num_rows);
+    for i in 0..num_rows {
+        int64_builder.append_value(i as i64);
+    }
+    let int64_array = Arc::new(int64_builder.finish());
+
+    let mut fields = vec![];
+    let mut columns: Vec<ArrayRef> = vec![];
+
+    // int64 column
+    fields.push(Field::new("c0", DataType::Int64, false));
+    columns.push(int64_array);
+
+    let schema = Schema::new(fields);
+    RecordBatch::try_new(Arc::new(schema), columns).unwrap()
+}
+
+fn config() -> Criterion {
+    Criterion::default()
+        .measurement_time(Duration::from_millis(500))
+        .warm_up_time(Duration::from_millis(500))
+}
+
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
new file mode 100644
index 0000000000..ed64b80e78
--- /dev/null
+++ b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::Field;
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::util::spark_bloom_filter;
+use crate::execution::datafusion::util::spark_bloom_filter::SparkBloomFilter;
+use arrow::array::ArrayRef;
+use arrow_array::BinaryArray;
+use datafusion::error::Result;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_common::{downcast_value, DataFusionError, ScalarValue};
+use datafusion_expr::{
+    function::{AccumulatorArgs, StateFieldsArgs},
+    Accumulator, AggregateUDFImpl, Signature,
+};
+use datafusion_physical_expr::expressions::Literal;
+
+#[derive(Debug, Clone)]
+pub struct BloomFilterAgg {
+    name: String,
+    signature: Signature,
+    expr: Arc<dyn PhysicalExpr>,
+    num_items: i32,
+    num_bits: i32,
+}
+
+#[inline]
+fn extract_i32_from_literal(expr: Arc<dyn PhysicalExpr>) -> i32 {
+    match expr.as_any().downcast_ref::<Literal>().unwrap().value() {
+        ScalarValue::Int64(scalar_value) => scalar_value.unwrap() as i32,
+        _ => {
+            unreachable!()
+        }
+    }
+}
+
+impl BloomFilterAgg {
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        num_items: Arc<dyn PhysicalExpr>,
+        num_bits: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        assert!(matches!(data_type, DataType::Binary));
+        Self {
+            name: name.into(),
+            signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
+            expr,
+            num_items: extract_i32_from_literal(num_items),
+            num_bits: extract_i32_from_literal(num_bits),
+        }
+    }
+}
+
+impl AggregateUDFImpl for BloomFilterAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "bloom_filter_agg"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Binary)
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(SparkBloomFilter::from((
+            spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
+            self.num_bits,
+        ))))
+    }
+
+    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
+        Ok(vec![Field::new("bits", DataType::Binary, false)])
+    }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        false
+    }
+}
+
+impl Accumulator for SparkBloomFilter {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = ScalarValue::try_from_array(arr, index)?;
+
+            if let ScalarValue::Int64(Some(value)) = v {
+                self.put_long(value);
+            } else {
+                unreachable!()
+            }
+            Ok(())
+        })
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Binary(Some(self.spark_serialization())))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        // There might be a more efficient way to do this by transmuting since calling state() on an
+        // Accumulator is considered destructive.
+        let state_sv = ScalarValue::Binary(Some(self.state_as_bytes()));
+        Ok(vec![state_sv])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        assert_eq!(
+            states.len(),
+            1,
+            "Expect one element in 'states' but found {}",
+            states.len()
+        );
+        assert_eq!(states[0].len(), 1);
+        let state_sv = downcast_value!(states[0], BinaryArray);
+        self.merge_filter(state_sv.value_data());
+        Ok(())
+    }
+}
diff --git a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
index 462a22247f..de922d8312 100644
--- a/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
+++ b/native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -72,7 +72,7 @@ fn evaluate_bloom_filter(
     let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
     match bloom_filter_bytes {
         ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
-            Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
+            Ok(v.map(|v| SparkBloomFilter::from(v.as_bytes())))
         }
         _ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
     }
diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs
index 10c9d30920..48b80384b0 100644
--- a/native/core/src/execution/datafusion/expressions/mod.rs
+++ b/native/core/src/execution/datafusion/expressions/mod.rs
@@ -25,6 +25,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
 use crate::errors::CometError;
 pub mod avg;
 pub mod avg_decimal;
+pub mod bloom_filter_agg;
 pub mod bloom_filter_might_contain;
 pub mod comet_scalar_funcs;
 pub mod correlation;
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index d63fd70784..5b53cb3930 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -28,6 +28,7 @@ use crate::{
         avg::Avg,
         avg_decimal::AvgDecimal,
         bitwise_not::BitwiseNotExpr,
+        bloom_filter_agg::BloomFilterAgg,
         bloom_filter_might_contain::BloomFilterMightContain,
         checkoverflow::CheckOverflow,
         correlation::Correlation,
@@ -1620,6 +1621,22 @@ impl PhysicalPlanner {
                 ));
                 Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func)
             }
+            AggExprStruct::BloomFilterAgg(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
+                let num_items =
+                    self.create_expr(expr.num_items.as_ref().unwrap(), Arc::clone(&schema))?;
+                let num_bits =
+                    self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
+                    Arc::clone(&child),
+                    Arc::clone(&num_items),
+                    Arc::clone(&num_bits),
+                    "bloom_filter_agg",
+                    datatype,
+                ));
+                Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
+            }
         }
     }
diff --git a/native/core/src/execution/datafusion/util/spark_bit_array.rs b/native/core/src/execution/datafusion/util/spark_bit_array.rs
index 9729627df3..68b97d6608 100644
--- a/native/core/src/execution/datafusion/util/spark_bit_array.rs
+++ b/native/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -15,6 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::common::bit;
+use arrow_buffer::ToByteSlice;
+use std::iter::zip;
+
 /// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
 /// used in the BloomFilter implementation. Some methods are not implemented as they are not
 /// required for the current use case.
@@ -55,12 +59,50 @@ impl SparkBitArray {
     }
 
     pub fn bit_size(&self) -> u64 {
-        self.data.len() as u64 * 64
+        self.word_size() as u64 * 64
+    }
+
+    pub fn byte_size(&self) -> usize {
+        self.word_size() * 8
+    }
+
+    pub fn word_size(&self) -> usize {
+        self.data.len()
     }
 
     pub fn cardinality(&self) -> usize {
         self.bit_count
     }
+
+    pub fn to_bytes(&self) -> Vec<u8> {
+        Vec::from(self.data.to_byte_slice())
+    }
+
+    pub fn data(&self) -> Vec<u64> {
+        self.data.clone()
+    }
+
+    // Combines SparkBitArrays; `other` is a &[u8] because we expect it to come from an Arrow
+    // ScalarValue::Binary, which is a byte vector underneath rather than a word vector.
+    pub fn merge_bits(&mut self, other: &[u8]) {
+        assert_eq!(self.byte_size(), other.len());
+        let mut bit_count: usize = 0;
+        // For each word, merge the bits into self, and accumulate a new bit_count.
+        for i in zip(
+            self.data.iter_mut(),
+            other
+                .chunks(8)
+                .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
+        ) {
+            *i.0 |= i.1;
+            bit_count += i.0.count_ones() as usize;
+        }
+        self.bit_count = bit_count;
+    }
+}
+
+pub fn num_words(num_bits: i32) -> i32 {
+    bit::ceil(num_bits as usize, 64) as i32
 }
 
 #[cfg(test)]
@@ -128,4 +170,67 @@ mod test {
         // check cardinality
         assert_eq!(array.cardinality(), 6);
     }
+
+    #[test]
+    fn test_spark_bit_with_empty_buffer() {
+        let buf = vec![0u64; 4];
+        let array = SparkBitArray::new(buf);
+
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 0);
+
+        for n in 0..256 {
+            assert!(!array.get(n));
+        }
+    }
+
+    #[test]
+    fn test_spark_bit_with_full_buffer() {
+        let buf = vec![u64::MAX; 4];
+        let array = SparkBitArray::new(buf);
+
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 256);
+
+        for n in 0..256 {
+            assert!(array.get(n));
+        }
+    }
+
+    #[test]
+    fn test_spark_bit_merge() {
+        let buf1 = vec![0u64; 4];
+        let mut array1 = SparkBitArray::new(buf1);
+        let buf2 = vec![0u64; 4];
+        let mut array2 = SparkBitArray::new(buf2);
+
+        let primes = [
+            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83,
+            89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179,
+            181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251,
+        ];
+        let fibs = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233];
+
+        for n in fibs {
+            array1.set(n);
+        }
+
+        for n in primes {
+            array2.set(n);
+        }
+
+        assert_eq!(array1.cardinality(), fibs.len());
+        assert_eq!(array2.cardinality(), primes.len());
+
+        array1.merge_bits(array2.to_bytes().as_slice());
+
+        for n in fibs {
+            assert!(array1.get(n));
+        }
+
+        for n in primes {
+            assert!(array1.get(n));
+        }
+        assert_eq!(array1.cardinality(), 60);
+    }
 }
diff --git a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
index 00f717676d..22a84d8540 100644
--- a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
+++ b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -15,9 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::execution::datafusion::util::spark_bit_array;
 use crate::execution::datafusion::util::spark_bit_array::SparkBitArray;
 use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array};
+use arrow_buffer::ToByteSlice;
 use datafusion_comet_spark_expr::spark_hash::spark_compatible_murmur3_hash;
+use std::cmp;
 
 const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
 
@@ -30,8 +33,29 @@ pub struct SparkBloomFilter {
     num_hash_functions: u32,
 }
 
-impl SparkBloomFilter {
-    pub fn new(buf: &[u8]) -> Self {
+pub fn optimal_num_hash_functions(expected_items: i32, num_bits: i32) -> i32 {
+    cmp::max(
+        1,
+        ((num_bits as f64 / expected_items as f64) * 2.0_f64.ln()).round() as i32,
+    )
+}
+
+impl From<(i32, i32)> for SparkBloomFilter {
+    /// Creates an empty SparkBloomFilter given number of hash functions and bits.
+    fn from((num_hash_functions, num_bits): (i32, i32)) -> Self {
+        let num_words = spark_bit_array::num_words(num_bits);
+        let bits = vec![0u64; num_words as usize];
+        Self {
+            bits: SparkBitArray::new(bits),
+            num_hash_functions: num_hash_functions as u32,
+        }
+    }
+}
+
+impl From<&[u8]> for SparkBloomFilter {
+    /// Creates a SparkBloomFilter from a serialized byte array conforming to Spark's BloomFilter
+    /// binary format version 1.
+    fn from(buf: &[u8]) -> Self {
         let mut offset = 0;
         let version = read_num_be_bytes!(i32, 4, buf[offset..]);
         offset += 4;
@@ -54,6 +78,25 @@ impl SparkBloomFilter {
             num_hash_functions: num_hash_functions as u32,
         }
     }
+}
+
+impl SparkBloomFilter {
+    /// Serializes a SparkBloomFilter to a byte array conforming to Spark's BloomFilter
+    /// binary format version 1.
+    pub fn spark_serialization(&self) -> Vec<u8> {
+        // There might be a more efficient way to do this, even with all the endianness handling.
+        let mut spark_bloom_filter: Vec<u8> = 1_u32.to_be_bytes().to_vec();
+        spark_bloom_filter.append(&mut self.num_hash_functions.to_be_bytes().to_vec());
+        spark_bloom_filter.append(&mut (self.bits.word_size() as u32).to_be_bytes().to_vec());
+        let mut filter_state: Vec<u64> = self.bits.data();
+        for i in filter_state.iter_mut() {
+            *i = i.to_be();
+        }
+        // Does it make sense to do a std::mem::take of filter_state here? Unclear whether a deep
+        // copy of filter_state from a Vec<u64> to a Vec<u8> is happening here.
+        spark_bloom_filter.append(&mut Vec::from(filter_state.to_byte_slice()));
+        spark_bloom_filter
+    }
 
     pub fn put_long(&mut self, item: i64) -> bool {
         // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce
@@ -94,4 +137,17 @@ impl SparkBloomFilter {
             .map(|v| v.map(|x| self.might_contain_long(x)))
             .collect()
     }
+
+    pub fn state_as_bytes(&self) -> Vec<u8> {
+        self.bits.to_bytes()
+    }
+
+    pub fn merge_filter(&mut self, other: &[u8]) {
+        assert_eq!(
+            other.len(),
+            self.bits.byte_size(),
+            "Cannot merge SparkBloomFilters with different lengths."
+        );
+        self.bits.merge_bits(other);
+    }
 }
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 1a3e3c9fcd..796ca5be1b 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -101,6 +101,7 @@ message AggExpr {
     Variance variance = 13;
     Stddev stddev = 14;
    Correlation correlation = 15;
+    BloomFilterAgg bloomFilterAgg = 16;
   }
 }
@@ -192,6 +193,13 @@ message Correlation {
   DataType datatype = 4;
 }
 
+message BloomFilterAgg {
+  Expr child = 1;
+  Expr numItems = 2;
+  Expr numBits = 3;
+  DataType datatype = 4;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index c6e692cc4a..3805d418b8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Complete, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, Complete, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -760,6 +760,39 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           withInfo(aggExpr, child1, child2)
           None
         }
+
+      case bloom_filter @ BloomFilterAggregate(child, numItems, numBits, _, _) =>
+        // We ignore mutableAggBufferOffset and inputAggBufferOffset because they are
+        // implementation details for Spark's ObjectHashAggregate.
+        val childExpr = exprToProto(child, inputs, binding)
+        val numItemsExpr = exprToProto(numItems, inputs, binding)
+        val numBitsExpr = exprToProto(numBits, inputs, binding)
+        val dataType = serializeDataType(bloom_filter.dataType)
+
+        // TODO: Support more types
+        // https://github.com/apache/datafusion-comet/issues/1023
+        if (childExpr.isDefined &&
+          child.dataType
+            .isInstanceOf[LongType] &&
+          numItemsExpr.isDefined &&
+          numBitsExpr.isDefined &&
+          dataType.isDefined) {
+          val bloomFilterAggBuilder = ExprOuterClass.BloomFilterAgg.newBuilder()
+          bloomFilterAggBuilder.setChild(childExpr.get)
+          bloomFilterAggBuilder.setNumItems(numItemsExpr.get)
+          bloomFilterAggBuilder.setNumBits(numBitsExpr.get)
+          bloomFilterAggBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setBloomFilterAgg(bloomFilterAggBuilder)
+              .build())
+        } else {
+          withInfo(aggExpr, child, numItems, numBits)
+          None
+        }
+
       case fn =>
         val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
         emitWarning(msg)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 05aa237239..78f59cbea0 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,10 +31,10 @@ import org.scalatest.Tag
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame, DataFrameWriter, Row, SaveMode}
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
-import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Hex}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate}
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometSparkToColumnarExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
@@ -911,6 +911,29 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("bloom_filter_agg") {
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    withParquetTable(
+      (0 until 100)
+        .map(_ => (Random.nextInt(), Random.nextInt() % 5)),
+      "tbl") {
+      val df = sql("SELECT bloom_filter_agg(cast(_2 as long)) FROM tbl")
+      checkSparkAnswerAndOperator(df)
+    }
+
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+  }
+
   test("sort (non-global)") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
       val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
index 7205484e5c..3dd930f671 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
@@ -22,6 +22,9 @@ package org.apache.spark.sql.benchmark
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.{CometConf, CometSparkSessionExtensions}
@@ -222,23 +225,77 @@ object CometExecBenchmark extends CometBenchmarkBase {
     }
   }
 
-  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
-    runBenchmarkWithTable("Subquery", 1024 * 1024 * 10) { v =>
-      subqueryExecBenchmark(v)
-    }
+  // BloomFilterAgg takes an argument for the expected number of distinct values, which determines filter size and
+  // number of hash functions. We use the cardinality as a hint to the aggregate, otherwise the default Spark values
+  // make a big filter with a lot of hash functions.
+  def bloomFilterAggregate(values: Int, cardinality: Int): Unit = {
+    val benchmark =
+      new Benchmark(
+        s"BloomFilterAggregate Exec (cardinality $cardinality)",
+        values,
+        output = output)
 
-    runBenchmarkWithTable("Expand", 1024 * 1024 * 10) { v =>
-      expandExecBenchmark(v)
-    }
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) => new BloomFilterAggregate(children.head, children(1)))
+
+    withTempPath { dir =>
+      withTempTable("parquetV1Table") {
+        prepareTable(dir, spark.sql(s"SELECT floor(rand() * $cardinality) as key FROM $tbl"))
 
-    runBenchmarkWithTable("Project + Filter", 1024 * 1024 * 10) { v =>
-      for (fractionOfZeros <- List(0.0, 0.50, 0.95)) {
-        numericFilterExecBenchmark(v, fractionOfZeros)
+        val query =
+          s"SELECT bloom_filter_agg(cast(key as long), cast($cardinality as long)) FROM parquetV1Table"
+
+        benchmark.addCase("SQL Parquet - Spark (BloomFilterAgg)") { _ =>
+          spark.sql(query).noop()
+        }
+
+        benchmark.addCase("SQL Parquet - Comet (Scan) (BloomFilterAgg)") { _ =>
+          withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
+            spark.sql(query).noop()
+          }
+        }
+
+        benchmark.addCase("SQL Parquet - Comet (Scan, Exec) (BloomFilterAgg)") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true") {
+            spark.sql(query).noop()
+          }
+        }
+
+        benchmark.run()
       }
     }
 
-    runBenchmarkWithTable("Sort", 1024 * 1024 * 10) { v =>
-      sortExecBenchmark(v)
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+  }
+
+  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
+//    runBenchmarkWithTable("Subquery", 1024 * 1024 * 10) { v =>
+//      subqueryExecBenchmark(v)
+//    }
+//
+//    runBenchmarkWithTable("Expand", 1024 * 1024 * 10) { v =>
+//      expandExecBenchmark(v)
+//    }
+//
+//    runBenchmarkWithTable("Project + Filter", 1024 * 1024 * 10) { v =>
+//      for (fractionOfZeros <- List(0.0, 0.50, 0.95)) {
+//        numericFilterExecBenchmark(v, fractionOfZeros)
+//      }
+//    }
+//
+//    runBenchmarkWithTable("Sort", 1024 * 1024 * 10) { v =>
+//      sortExecBenchmark(v)
+//    }
+
+    runBenchmarkWithTable("BloomFilterAggregate", 1024 * 1024 * 10) { v =>
+      for (card <- List(100, 1024, 1024 * 1024)) {
+        bloomFilterAggregate(v, card)
+      }
     }
   }
 }