@@ -17,6 +17,8 @@

package org.apache.spark.sql.connector.expressions.filter;

import java.io.Serializable;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;
@@ -27,7 +29,7 @@
* @since 3.3.0
*/
@Evolving
-public abstract class Filter implements Expression {
+public abstract class Filter implements Expression, Serializable {

protected static final NamedReference[] EMPTY_REFERENCE = new NamedReference[0];

@@ -35,6 +35,7 @@ import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, UnboundFunction}
import org.apache.spark.sql.connector.expressions.{NamedReference, Transform}
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED, LEGACY_CTE_PRECEDENCE_POLICY}
import org.apache.spark.sql.sources.Filter
@@ -1136,6 +1137,11 @@ object QueryCompilationErrors {
s"Fail to rebuild expression: missing key $filter in `translatedFilterToExpr`")
}

def failedToRebuildExpressionError(filter: V2Filter): Throwable = {
new AnalysisException(
s"Fail to rebuild expression: missing key $filter in `translatedFilterToExpr`")
}

def dataTypeUnsupportedByDataSourceError(format: String, field: StructField): Throwable = {
new AnalysisException(
s"$format data source does not support ${field.dataType.catalogString} data type.")
@@ -2392,4 +2398,12 @@ object QueryCompilationErrors {
errorClass = "INVALID_JSON_SCHEMA_MAPTYPE",
messageParameters = Array(schema.toString))
}

def invalidDataTypeForFilterValue(value: Any): Throwable = {
new AnalysisException(s"Filter value $value has invalid data type")
}

def invalidFilter(filter: Any): Throwable = {
new AnalysisException(s"Invalid Filter $filter")
}
}
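A hedged sketch of how the two new helpers surface to callers; the map value below is hypothetical, chosen because the literal conversion added later in this diff has no case for it:

// Both helpers wrap their message in an AnalysisException (sketch, hypothetical value).
val badValue: Any = Map("k" -> 1)
throw QueryCompilationErrors.invalidDataTypeForFilterValue(badValue)
// => org.apache.spark.sql.AnalysisException: Filter value Map(k -> 1) has invalid data type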
@@ -17,6 +17,9 @@

package org.apache.spark.sql.execution.datasources

import java.math.{BigDecimal => JavaBigDecimal, BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.util.Locale

import scala.collection.JavaConverters._
@@ -26,20 +29,23 @@ import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.SparkUpgradeException
-import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_LEGACY_INT96, SPARK_VERSION_METADATA_KEY}
+import org.apache.spark.sql.{sources, SPARK_LEGACY_DATETIME, SPARK_LEGACY_INT96, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
-import org.apache.spark.sql.catalyst.util.RebaseDateTime
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime}
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, Filter => V2Filter, GreaterThan => V2GreaterThan, GreaterThanOrEqual => V2GreaterThanOrEqual, In => V2In, IsNotNull => V2IsNotNull, IsNull => V2IsNull, LessThan => V2LessThan, LessThanOrEqual => V2LessThanOrEqual, Not => V2Not, Or => V2Or, StringContains => V2StringContains, StringEndsWith => V2StringEndsWith, StringStartsWith => V2StringStartsWith}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils


object DataSourceUtils extends PredicateHelper {
/**
* The key to use for storing partitionBy columns as options.
@@ -261,4 +267,116 @@ object DataSourceUtils extends PredicateHelper {
dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
(ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters)
}

def convertV1FilterToV2(v1Filter: sources.Filter): V2Filter = {
v1Filter match {
case _: sources.AlwaysFalse =>
new V2AlwaysFalse
case _: sources.AlwaysTrue =>
new V2AlwaysTrue
case e: sources.EqualNullSafe =>
new V2EqualNullSafe(FieldReference(e.attribute), getLiteralValue(e.value))
case equal: sources.EqualTo =>
new V2EqualTo(FieldReference(equal.attribute), getLiteralValue(equal.value))
case g: sources.GreaterThan =>
new V2GreaterThan(FieldReference(g.attribute), getLiteralValue(g.value))
case ge: sources.GreaterThanOrEqual =>
new V2GreaterThanOrEqual(FieldReference(ge.attribute), getLiteralValue(ge.value))
case in: sources.In =>
  new V2In(FieldReference(in.attribute), in.values.map(value => getLiteralValue(value)))
case notNull: sources.IsNotNull =>
new V2IsNotNull(FieldReference(notNull.attribute))
case isNull: sources.IsNull =>
new V2IsNull(FieldReference(isNull.attribute))
case l: sources.LessThan =>
new V2LessThan(FieldReference(l.attribute), getLiteralValue(l.value))
case le: sources.LessThanOrEqual =>
new V2LessThanOrEqual(FieldReference(le.attribute), getLiteralValue(le.value))
case contains: sources.StringContains =>
new V2StringContains(
FieldReference(contains.attribute), UTF8String.fromString(contains.value))
case ends: sources.StringEndsWith =>
new V2StringEndsWith(FieldReference(ends.attribute), UTF8String.fromString(ends.value))
case starts: sources.StringStartsWith =>
new V2StringStartsWith(
FieldReference(starts.attribute), UTF8String.fromString(starts.value))
case and: sources.And =>
new V2And(convertV1FilterToV2(and.left), convertV1FilterToV2(and.right))
case or: sources.Or =>
new V2Or(convertV1FilterToV2(or.left), convertV1FilterToV2(or.right))
case not: sources.Not =>
new V2Not(convertV1FilterToV2(not.child))
case _ => throw QueryCompilationErrors.invalidFilter(v1Filter)
}
}
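// Illustrative sketch (not part of the patch): nested filters convert recursively, e.g.
//   convertV1FilterToV2(sources.And(sources.EqualTo("id", 1), sources.IsNotNull("name")))
// returns the equivalent of
//   new V2And(new V2EqualTo(FieldReference("id"), LiteralValue(1, IntegerType)),
//     new V2IsNotNull(FieldReference("name")))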

def getLiteralValue(value: Any): LiteralValue[_] = value match {
case _: JavaBigDecimal =>
// Review comment (Contributor): according to DataSourceStrategy.translateLeafNodeFilter,
// the value of decimal type can only be Decimal.
LiteralValue(Decimal(value.asInstanceOf[JavaBigDecimal]), DecimalType.SYSTEM_DEFAULT)
case _: JavaBigInteger =>
LiteralValue(Decimal(value.asInstanceOf[JavaBigInteger]), DecimalType.SYSTEM_DEFAULT)
case _: BigDecimal =>
LiteralValue(Decimal(value.asInstanceOf[BigDecimal]), DecimalType.SYSTEM_DEFAULT)
case _: Boolean => LiteralValue(value, BooleanType)
case _: Byte => LiteralValue(value, ByteType)
// Review comment (Contributor): this can be ambiguous, as it can be an array type.
case _: Array[Byte] => LiteralValue(value, BinaryType)
case _: Date =>
val date = DateTimeUtils.fromJavaDate(value.asInstanceOf[Date])
LiteralValue(date, DateType)
case _: LocalDate =>
val date = DateTimeUtils.localDateToDays(value.asInstanceOf[LocalDate])
LiteralValue(date, DateType)
case _: Double => LiteralValue(value, DoubleType)
case _: Float => LiteralValue(value, FloatType)
case _: Integer => LiteralValue(value, IntegerType)
case _: Long => LiteralValue(value, LongType)
case _: Short => LiteralValue(value, ShortType)
case _: String => LiteralValue(UTF8String.fromString(value.toString), StringType)
case _: Timestamp =>
val ts = DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[Timestamp])
LiteralValue(ts, TimestampType)
case _: Instant =>
val ts = DateTimeUtils.instantToMicros(value.asInstanceOf[Instant])
LiteralValue(ts, TimestampType)
case _ =>
// Review comment (Contributor): how about array/map/struct?
throw QueryCompilationErrors.invalidDataTypeForFilterValue(value)
}
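// Illustrative sketch (not part of the patch): getLiteralValue pairs each supported JVM
// value with its Catalyst representation and type, e.g.
//   getLiteralValue("abc") == LiteralValue(UTF8String.fromString("abc"), StringType)
//   getLiteralValue(java.sql.Date.valueOf("2022-01-01")) == LiteralValue(18993, DateType)
// where 18993 is the number of days since the Unix epoch.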

def convertV2FilterToV1(v2Filter: V2Filter): sources.Filter = {
v2Filter match {
case _: V2AlwaysFalse => sources.AlwaysFalse
case _: V2AlwaysTrue => sources.AlwaysTrue
case e: V2EqualNullSafe => sources.EqualNullSafe(e.column.describe,
CatalystTypeConverters.convertToScala(e.value.value, e.value.dataType))
case equal: V2EqualTo => sources.EqualTo(equal.column.describe,
CatalystTypeConverters.convertToScala(equal.value.value, equal.value.dataType))
case g: V2GreaterThan => sources.GreaterThan(g.column.describe,
CatalystTypeConverters.convertToScala(g.value.value, g.value.dataType))
case ge: V2GreaterThanOrEqual => sources.GreaterThanOrEqual(ge.column.describe,
CatalystTypeConverters.convertToScala(ge.value.value, ge.value.dataType))
case in: V2In =>
  sources.In(in.column.describe,
    in.values.map(v => CatalystTypeConverters.convertToScala(v.value, v.dataType)))
case notNull: V2IsNotNull => sources.IsNotNull(notNull.column.describe)
case isNull: V2IsNull => sources.IsNull(isNull.column.describe)
case l: V2LessThan => sources.LessThan(l.column.describe,
CatalystTypeConverters.convertToScala(l.value.value, l.value.dataType))
case le: V2LessThanOrEqual => sources.LessThanOrEqual(le.column.describe,
CatalystTypeConverters.convertToScala(le.value.value, le.value.dataType))
case contains: V2StringContains =>
sources.StringContains(contains.column.describe, contains.value.toString)
case ends: V2StringEndsWith =>
sources.StringEndsWith(ends.column.describe, ends.value.toString)
case starts: V2StringStartsWith =>
sources.StringStartsWith(starts.column.describe, starts.value.toString)
case and: V2And => sources.And(convertV2FilterToV1(and.left), convertV2FilterToV1(and.right))
case or: V2Or => sources.Or(convertV2FilterToV1(or.left), convertV2FilterToV1(or.right))
case not: V2Not => sources.Not(convertV2FilterToV1(not.child))
case _ => throw QueryCompilationErrors.invalidFilter(v2Filter)
}
}
}
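Since the two conversions cover the same set of leaf filter types, they should round-trip for supported shapes. A minimal sketch under that assumption:

import org.apache.spark.sql.sources

// V1 -> V2 -> V1 yields an equal filter for the supported leaf types.
val v1: sources.Filter = sources.Or(sources.GreaterThan("age", 21), sources.IsNull("age"))
val v2 = DataSourceUtils.convertV1FilterToV2(v1)
assert(DataSourceUtils.convertV2FilterToV1(v2) == v1)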