From 3906738a6e7ffe7bc8c50ad6e2c69c6bd7193cb2 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Fri, 6 Sep 2024 20:50:29 -0400
Subject: [PATCH 1/6] Start working on GetArrayStructFields

---
 .../apache/comet/serde/QueryPlanSerde.scala   | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index d77fac4710..eb7528d226 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2508,6 +2508,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
           None
         }
+
+      // case GetArrayStructFields(child, field, ordinal, numFields, containsNull) =>
+      //   val childExpr = exprToProto(child, inputs, binding)
+      //   val ordinalExpr = exprToProto(ordinal, inputs, binding)
+
+      //   if (childExpr.isDefined && ordinalExpr.isDefined) {
+      //     val listExtractBuilder = ExprOuterClass.ListExtract
+      //       .newBuilder()
+      //       .setChild(childExpr.get)
+      //       .setOrdinal(ordinalExpr.get)
+      //       .setOneBased(false)
+      //       .setFailOnError(failOnError)
+
+      //     Some(
+      //       ExprOuterClass.Expr
+      //         .newBuilder()
+      //         .setListExtract(listExtractBuilder)
+      //         .build())
+      //   } else {
+      //     withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
+      //     None
+      //   }
 
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
From 6e289b351f0da1bfb81ff2fea15fef27c6915226 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Sun, 8 Sep 2024 10:24:33 -0400
Subject: [PATCH 2/6] Almost have it working

---
 .../core/src/execution/datafusion/planner.rs  |  13 +-
 native/proto/src/proto/expr.proto             |   6 +
 native/spark-expr/src/lib.rs                  |   2 +-
 native/spark-expr/src/list.rs                 | 132 +++++++++++++++++-
 .../apache/comet/serde/QueryPlanSerde.scala   |  41 +++---
 .../apache/comet/CometExpressionSuite.scala   |  14 ++
 6 files changed, 182 insertions(+), 26 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index a305774397..fad4799bc7 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -96,8 +96,8 @@ use datafusion_comet_proto::{
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
 };
 use datafusion_comet_spark_expr::{
-    Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, ListExtract,
-    MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
+    Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr,
+    ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
 };
 use datafusion_common::scalar::ScalarStructBuilder;
 use datafusion_common::{
@@ -680,6 +680,15 @@ impl PhysicalPlanner {
                     expr.fail_on_error,
                 )))
             }
+            ExprStruct::GetArrayStructFields(expr) => {
+                let child =
+                    self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
+
+                Ok(Arc::new(GetArrayStructFields::new(
+                    child,
+                    expr.ordinal as usize,
+                )))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 88940f386c..1a3e3c9fcd 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -81,6 +81,7 @@ message Expr {
     GetStructField get_struct_field = 54;
     ToJson to_json = 55;
     ListExtract list_extract = 56;
+    GetArrayStructFields get_array_struct_fields = 57;
   }
 }
 
@@ -517,6 +518,11 @@ message ListExtract {
   bool fail_on_error = 5;
 }
 
+message GetArrayStructFields {
+  Expr child = 1;
+  int32 ordinal = 2;
+}
+
 enum SortDirection {
   Ascending = 0;
   Descending = 1;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c4b1c99ba9..cc22dfcbce 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -38,7 +38,7 @@ mod xxhash64;
 pub use cast::{spark_cast, Cast};
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
-pub use list::ListExtract;
+pub use list::{GetArrayStructFields, ListExtract};
 pub use regexp::RLike;
 pub use structs::{CreateNamedStruct, GetStructField};
 pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index 0b85a84248..ca6b683414 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
-use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
+use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait, StructArray};
 use arrow_schema::{DataType, FieldRef, Schema};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
@@ -275,6 +275,136 @@ impl PartialEq<dyn Any> for ListExtract {
     }
 }
 
+#[derive(Debug, Hash)]
+pub struct GetArrayStructFields {
+    child: Arc<dyn PhysicalExpr>,
+    ordinal: usize,
+}
+
+impl GetArrayStructFields {
+    pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
+        Self { child, ordinal }
+    }
+
+    fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.child.data_type(input_schema)? {
+            DataType::List(field) | DataType::LargeList(field) => Ok(field),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in ListExtract: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.list_field(input_schema)?.data_type() {
+            DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in ListExtract: {:?}",
+                data_type
+            ))),
+        }
+    }
+}
+
+impl PhysicalExpr for GetArrayStructFields {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
+        Ok(self.child_field(input_schema)?.data_type().clone())
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+        Ok(self.list_field(input_schema)?.is_nullable()
+            || self.child_field(input_schema)?.is_nullable())
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
+        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
+
+        match child_value.data_type() {
+            DataType::List(_) => {
+                let list_array = as_list_array(&child_value)?;
+
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            DataType::LargeList(_) => {
+                let list_array = as_large_list_array(&child_value)?;
+
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected child type for ListExtract: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.child]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+        match children.len() {
+            1 => Ok(Arc::new(GetArrayStructFields::new(
+                Arc::clone(&children[0]),
+                self.ordinal,
+            ))),
+            _ => internal_err!("GetArrayStructFields should have exactly one child"),
+        }
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.child.hash(&mut s);
+        self.ordinal.hash(&mut s);
+        self.hash(&mut s);
+    }
+}
+
+impl Display for GetArrayStructFields {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
+            self.child, self.ordinal
+        )
+    }
+}
+
+impl PartialEq<dyn Any> for GetArrayStructFields {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| self.child.eq(&x.child) && self.ordinal.eq(&x.ordinal))
+            .unwrap_or(false)
+    }
+}
+
+fn get_array_struct_fields<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    ordinal: usize,
+) -> DataFusionResult<ColumnarValue> {
+    let values = list_array
+        .values()
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .expect("A struct is expected");
+
+    let column = Arc::clone(values.column(ordinal));
+    let field = Arc::clone(&values.fields()[ordinal]);
+
+    let offsets = list_array.offsets();
+    GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());
+
+    Ok(ColumnarValue::Array(Arc::clone(values.column(ordinal))))
+}
+
 #[cfg(test)]
 mod test {
     use crate::list::{list_extract, zero_based_index};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index eb7528d226..2e57b454b7 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2508,28 +2508,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
           None
         }
-
-      // case GetArrayStructFields(child, field, ordinal, numFields, containsNull) =>
-      //   val childExpr = exprToProto(child, inputs, binding)
-      //   val ordinalExpr = exprToProto(ordinal, inputs, binding)
-
-      //   if (childExpr.isDefined && ordinalExpr.isDefined) {
-      //     val listExtractBuilder = ExprOuterClass.ListExtract
-      //       .newBuilder()
-      //       .setChild(childExpr.get)
-      //       .setOrdinal(ordinalExpr.get)
-      //       .setOneBased(false)
-      //       .setFailOnError(failOnError)
-
-      //     Some(
-      //       ExprOuterClass.Expr
-      //         .newBuilder()
-      //         .setListExtract(listExtractBuilder)
-      //         .build())
-      //   } else {
-      //     withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
-      //     None
-      //   }
+
+      case GetArrayStructFields(child, _, ordinal, _, containsNull) =>
+        val childExpr = exprToProto(child, inputs, binding)
+
+        if (childExpr.isDefined) {
+          val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+            .newBuilder()
+            .setChild(childExpr.get)
+            .setOrdinal(ordinal)
+
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setGetArrayStructFields(arrayStructFieldsBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
+          None
+        }
 
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 3701be5fb6..848169a79e 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2165,4 +2165,18 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("GetArrayStructFields") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+          val df = spark.read.parquet(path.toString).select(array(struct(col("_2"), col("_3"), col("_4"))).alias("arr"))
+          checkSparkAnswerAndOperator(df.select("arr._2"))
+          // checkSparkAnswerAndOperator(df.select(array(struct(col("_8").alias("a")), struct(col("_13").alias("a")))))
+        }
+      }
+    }
+  }
 }
From 4ce5cc55c6df591057d53b53b9bb76ca54a5b9fb Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Wed, 11 Sep 2024 06:34:25 -0400
Subject: [PATCH 3/6] Working

---
 native/spark-expr/src/list.rs                 | 52 +++++++++++--------
 .../apache/comet/CometExpressionSuite.scala   | 15 ++++--
 2 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index ca6b683414..a376198db7 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -290,7 +290,7 @@ impl GetArrayStructFields {
         match self.child.data_type(input_schema)? {
             DataType::List(field) | DataType::LargeList(field) => Ok(field),
             data_type => Err(DataFusionError::Internal(format!(
-                "Unexpected data type in ListExtract: {:?}",
+                "Unexpected data type in GetArrayStructFields: {:?}",
                 data_type
             ))),
         }
@@ -300,7 +300,7 @@ impl GetArrayStructFields {
         match self.list_field(input_schema)?.data_type() {
             DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
             data_type => Err(DataFusionError::Internal(format!(
-                "Unexpected data type in ListExtract: {:?}",
+                "Unexpected data type in GetArrayStructFields: {:?}",
                 data_type
             ))),
         }
@@ -313,7 +313,15 @@ impl PhysicalExpr for GetArrayStructFields {
     }
 
     fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
-        Ok(self.child_field(input_schema)?.data_type().clone())
+        let struct_field = self.child_field(input_schema)?;
+        match self.child.data_type(input_schema)? {
+            DataType::List(_) => Ok(DataType::List(struct_field)),
+            DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
     }
 
     fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
@@ -367,44 +375,44 @@ impl PhysicalExpr for GetArrayStructFields {
     }
 }
 
+fn get_array_struct_fields<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    ordinal: usize,
+) -> DataFusionResult<ColumnarValue> {
+    let values = list_array
+        .values()
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .expect("A struct is expected");
+
+    let column = Arc::clone(values.column(ordinal));
+    let field = Arc::clone(&values.fields()[ordinal]);
+
+    let offsets = list_array.offsets();
+    let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());
+
+    Ok(ColumnarValue::Array(Arc::new(array)))
+}
+
 impl Display for GetArrayStructFields {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         write!(
             f,
             "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
             self.child, self.ordinal
         )
     }
 }
 
 impl PartialEq<dyn Any> for GetArrayStructFields {
     fn eq(&self, other: &dyn Any) -> bool {
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
             .map(|x| self.child.eq(&x.child) && self.ordinal.eq(&x.ordinal))
             .unwrap_or(false)
     }
 }
 
-fn get_array_struct_fields<O: OffsetSizeTrait>(
-    list_array: &GenericListArray<O>,
-    ordinal: usize,
-) -> DataFusionResult<ColumnarValue> {
-    let values = list_array
-        .values()
-        .as_any()
-        .downcast_ref::<StructArray>()
-        .expect("A struct is expected");
-
-    let column = Arc::clone(values.column(ordinal));
-    let field = Arc::clone(&values.fields()[ordinal]);
-
-    let offsets = list_array.offsets();
-    GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());
-
-    Ok(ColumnarValue::Array(Arc::clone(values.column(ordinal))))
-}
-
 #[cfg(test)]
 mod test {
     use crate::list::{list_extract, zero_based_index};
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 848169a79e..0d24b3617b 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2168,13 +2168,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("GetArrayStructFields") {
     Seq(true, false).foreach { dictionaryEnabled =>
-      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
+      withSQLConf(
+        SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName,
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true") {
         withTempDir { dir =>
           val path = new Path(dir.toURI.toString, "test.parquet")
           makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
-          val df = spark.read.parquet(path.toString).select(array(struct(col("_2"), col("_3"), col("_4"))).alias("arr"))
-          checkSparkAnswerAndOperator(df.select("arr._2"))
-          // checkSparkAnswerAndOperator(df.select(array(struct(col("_8").alias("a")), struct(col("_13").alias("a")))))
+          val df = spark.read
+            .parquet(path.toString)
+            .select(array(struct(col("_2"), col("_3"), col("_4")), lit(null)).alias("arr"))
+          checkSparkAnswerAndOperator(df.select("arr._2", "arr._3", "arr._4"))
+
+          // val df2 =
+          //   spark.range(10).withColumn("arr", array(struct(lit(1).alias("a"), lit(2).alias("b"))))
+          // checkSparkAnswerAndOperator(df2.select("arr.a", "arr.b"))
         }
       }
     }
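The fix that makes patch 3 "Working" is in the kernel: patch 2's get_array_struct_fields built the re-wrapped GenericListArray but discarded it, returning the flat struct child column (so the per-row list shape was lost), whereas patch 3 returns the constructed array, reusing the parent list's offsets and null buffer. A minimal standalone sketch of that idea against the arrow-rs API — the field names and values below are illustrative, not taken from the patch:

use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, ListArray, StringArray, StructArray};
use arrow_buffer::OffsetBuffer;
use arrow_schema::{DataType, Field, Fields};

fn main() {
    // arr: array<struct<a: int, b: string>> with rows [[{1,"x"}, {2,"y"}], [{3,"z"}]]
    let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let b: ArrayRef = Arc::new(StringArray::from(vec!["x", "y", "z"]));
    let fields = Fields::from(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]);
    let structs = StructArray::new(fields.clone(), vec![a, b], None);
    let item = Arc::new(Field::new("item", DataType::Struct(fields), true));
    let offsets = OffsetBuffer::new(vec![0, 2, 3].into());
    let arr = ListArray::new(item, offsets, Arc::new(structs), None);

    // The kernel: pick one struct child and re-wrap it with the parent list's
    // offsets and null buffer; no element data is copied.
    let ordinal = 0; // selects field "a"
    let structs = arr
        .values()
        .as_any()
        .downcast_ref::<StructArray>()
        .expect("list of structs");
    let column = Arc::clone(structs.column(ordinal));
    let field = Arc::clone(&structs.fields()[ordinal]);
    let extracted = ListArray::new(field, arr.offsets().clone(), column, arr.nulls().cloned());

    // Prints an array<int>: [[1, 2], [3]]
    println!("{:?}", extracted);
}

Because only the child column's Arc, the offset buffer, and the null buffer are cloned (all reference-counted), the extraction costs O(1) in the element data, which is why the expression needs no per-row loop.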
From ab834075ae7fa7326448ce230a25a627946286ed Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Sat, 28 Sep 2024 17:14:25 -0400
Subject: [PATCH 4/6] Add another test

---
 .../scala/org/apache/comet/CometExpressionSuite.scala | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 437efcc549..e6cfa70875 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2264,12 +2264,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
           val df = spark.read
             .parquet(path.toString)
-            .select(array(struct(col("_2"), col("_3"), col("_4")), lit(null)).alias("arr"))
+            .select(
+              array(struct(col("_2"), col("_3"), col("_4"), col("_8")), lit(null)).alias("arr"))
           checkSparkAnswerAndOperator(df.select("arr._2", "arr._3", "arr._4"))
 
-          // val df2 =
-          //   spark.range(10).withColumn("arr", array(struct(lit(1).alias("a"), lit(2).alias("b"))))
-          // checkSparkAnswerAndOperator(df2.select("arr.a", "arr.b"))
+          val complex = spark.read
+            .parquet(path.toString)
+            .select(array(struct(struct(col("_4"), col("_8")).alias("nested"))).alias("arr"))
+
+          checkSparkAnswerAndOperator(complex.select(col("arr.nested._4")))
         }
       }
     }
From 0168ef09075391c9d5d3201c3425c5d0212ea30e Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Mon, 30 Sep 2024 19:28:42 -0400
Subject: [PATCH 5/6] Remove unused

---
 .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e8138ee329..02b845e7c2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2542,7 +2542,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
         }
 
-      case GetArrayStructFields(child, _, ordinal, _, containsNull) =>
+      case GetArrayStructFields(child, _, ordinal, _, _) =>
        val childExpr = exprToProto(child, inputs, binding)
 
        if (childExpr.isDefined) {
From 7fe6963618cfbbed74dabf3133e3980b5c95a4a7 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Tue, 1 Oct 2024 19:28:48 -0400
Subject: [PATCH 6/6] Remove unused sql conf

---
 .../test/scala/org/apache/comet/CometExpressionSuite.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 664d0e1d80..da22df402b 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2274,9 +2274,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("GetArrayStructFields") {
     Seq(true, false).foreach { dictionaryEnabled =>
-      withSQLConf(
-        SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName,
-        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true") {
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
         withTempDir { dir =>
           val path = new Path(dir.toURI.toString, "test.parquet")
           makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
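With the series applied, a Spark projection like df.select("arr._2") over an array-of-structs column serializes to the new GetArrayStructFields proto message and evaluates natively. A hedged end-to-end sketch of driving the physical expression directly — the column name, field names, and values are invented for illustration, and the imports assume the crate layout this series uses (arrow-* crates plus datafusion_comet_spark_expr):

use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, ListArray, RecordBatch, StringArray, StructArray};
use arrow_buffer::OffsetBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_plan::PhysicalExpr;
use datafusion_comet_spark_expr::GetArrayStructFields;

fn main() -> datafusion_common::Result<()> {
    // A batch with one column "arr": array<struct<a: int, b: string>>.
    let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let b: ArrayRef = Arc::new(StringArray::from(vec!["x", "y", "z"]));
    let fields = Fields::from(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]);
    let structs = StructArray::new(fields.clone(), vec![a, b], None);
    let item = Arc::new(Field::new("item", DataType::Struct(fields), true));
    let arr = ListArray::new(item, OffsetBuffer::new(vec![0, 2, 3].into()), Arc::new(structs), None);

    let schema = Arc::new(Schema::new(vec![Field::new("arr", arr.data_type().clone(), true)]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(arr)])?;

    // Equivalent of Spark's `arr.a`: ordinal 0 selects struct field "a".
    let expr = GetArrayStructFields::new(Arc::new(Column::new("arr", 0)), 0);
    if let ColumnarValue::Array(out) = expr.evaluate(&batch)? {
        println!("{:?}", out); // array<int>: [[1, 2], [3]]
    }
    Ok(())
}

Null handling mirrors the nullable() implementation in patch 3: entries that are null at the list level stay null in the output, because the list's own null buffer is carried over unchanged — which is exactly what the lit(null) element in the test suite exercises.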