5 changes: 4 additions & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -590,7 +590,10 @@ trait Row extends Serializable {
case (r: Row, _) => r.jsonValue
case (v: Any, udt: UserDefinedType[Any @unchecked]) =>
val dataType = udt.sqlType
toJson(CatalystTypeConverters.convertToScala(udt.serialize(v), dataType), dataType)
toJson(CatalystTypeConverters.convertToScala(
udt.serialize(v),
dataType,
SQLConf.get.datetimeJava8ApiEnabled), dataType)
case _ =>
throw new IllegalArgumentException(s"Failed to convert value $value " +
s"(class of ${value.getClass}}) with the type of $dataType to JSON.")
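The JSON path in Row now passes the Java 8 datetime flag to convertToScala explicitly (still reading it from SQLConf at the call site) instead of letting the converter consult SQLConf internally. A minimal hedged sketch of what the flag changes; the day count is an arbitrary illustration value, since Catalyst stores DateType internally as days since the Unix epoch:

```scala
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.DateType

// Catalyst's internal value for a DateType column is an Int: days since 1970-01-01.
val internalDays = 18000

// With the flag on, the external value uses the Java 8 API ...
val asLocalDate =
  CatalystTypeConverters.convertToScala(internalDays, DateType, useJava8DateTimeApi = true)
// asLocalDate: java.time.LocalDate

// ... and with it off, the legacy java.sql.Date comes back, as before.
val asSqlDate =
  CatalystTypeConverters.convertToScala(internalDays, DateType, useJava8DateTimeApi = false)
// asSqlDate: java.sql.Date
```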
CatalystTypeConverters.scala
@@ -55,17 +55,17 @@ object CatalystTypeConverters {
}
}

private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
private def getConverterForType(
dataType: DataType,
useJava8DateTimeApi: Boolean): CatalystTypeConverter[Any, Any, Any] = {
val converter = dataType match {
case udt: UserDefinedType[_] => UDTConverter(udt)
case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
case structType: StructType => StructConverter(structType)
case arrayType: ArrayType => ArrayConverter(arrayType.elementType, useJava8DateTimeApi)
case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType, useJava8DateTimeApi)
case structType: StructType => StructConverter(structType, useJava8DateTimeApi)
case StringType => StringConverter
case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
case DateType => DateConverter
case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
case TimestampType => TimestampConverter
case DateType => if (useJava8DateTimeApi) LocalDateConverter else DateConverter
case TimestampType => if (useJava8DateTimeApi) InstantConverter else TimestampConverter
case dt: DecimalType => new DecimalConverter(dt)
case BooleanType => BooleanConverter
case ByteType => ByteConverter
@@ -156,9 +156,10 @@ object CatalystTypeConverters {

/** Converter for arrays, sequences, and Java iterables. */
private case class ArrayConverter(
elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {
elementType: DataType,
useJava8DateTimeApi: Boolean) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {

private[this] val elementConverter = getConverterForType(elementType)
private[this] val elementConverter = getConverterForType(elementType, useJava8DateTimeApi)

override def toCatalystImpl(scalaValue: Any): ArrayData = {
scalaValue match {
@@ -200,11 +201,12 @@ object CatalystTypeConverters {

private case class MapConverter(
keyType: DataType,
valueType: DataType)
valueType: DataType,
useJava8DateTimeApi: Boolean)
extends CatalystTypeConverter[Any, Map[Any, Any], MapData] {

private[this] val keyConverter = getConverterForType(keyType)
private[this] val valueConverter = getConverterForType(valueType)
private[this] val keyConverter = getConverterForType(keyType, useJava8DateTimeApi)
private[this] val valueConverter = getConverterForType(valueType, useJava8DateTimeApi)

override def toCatalystImpl(scalaValue: Any): MapData = {
val keyFunction = (k: Any) => keyConverter.toCatalyst(k)
@@ -240,9 +242,11 @@ object CatalystTypeConverters {
}

private case class StructConverter(
structType: StructType) extends CatalystTypeConverter[Any, Row, InternalRow] {
structType: StructType,
useJava8DateTimeApi: Boolean) extends CatalystTypeConverter[Any, Row, InternalRow] {

private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) }
private[this] val converters = structType.fields
.map { f => getConverterForType(f.dataType, useJava8DateTimeApi) }

override def toCatalystImpl(scalaValue: Any): InternalRow = scalaValue match {
case row: Row =>
@@ -404,7 +408,9 @@ object CatalystTypeConverters {
* Typical use case would be converting a collection of rows that have the same schema. You will
* call this function once to get a converter, and apply it to every row.
*/
def createToCatalystConverter(dataType: DataType): Any => Any = {
def createToCatalystConverter(
dataType: DataType,
useJava8DateTimeApi: Boolean = SQLConf.get.datetimeJava8ApiEnabled): Any => Any = {
if (isPrimitive(dataType)) {
// Although the `else` branch here is capable of handling inbound conversion of primitives,
// we add some special-case handling for those types here. The motivation for this relates to
@@ -422,7 +428,7 @@ object CatalystTypeConverters {
}
convert
} else {
getConverterForType(dataType).toCatalyst
getConverterForType(dataType, useJava8DateTimeApi).toCatalyst
}
}

@@ -431,11 +437,13 @@ object CatalystTypeConverters {
* Typical use case would be converting a collection of rows that have the same schema. You will
* call this function once to get a converter, and apply it to every row.
*/
def createToScalaConverter(dataType: DataType): Any => Any = {
def createToScalaConverter(
dataType: DataType,
useJava8DateTimeApi: Boolean = SQLConf.get.datetimeJava8ApiEnabled): Any => Any = {
if (isPrimitive(dataType)) {
identity
} else {
getConverterForType(dataType).toScala
getConverterForType(dataType, useJava8DateTimeApi).toScala
}
}

@@ -470,7 +478,7 @@ object CatalystTypeConverters {
* This method is slow, and for batch conversion you should be using converter
* produced by createToScalaConverter.
*/
def convertToScala(catalystValue: Any, dataType: DataType): Any = {
createToScalaConverter(dataType)(catalystValue)
def convertToScala(catalystValue: Any, dataType: DataType, useJava8DateTimeApi: Boolean): Any = {
createToScalaConverter(dataType, useJava8DateTimeApi)(catalystValue)
}
}
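As the scaladoc above says, the converter factories are built once per schema and then applied to every row. A hedged sketch of how the new parameter is meant to be used; the schema and the commented-out input collection are illustrative, not from this PR:

```scala
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{DateType, StructField, StructType}

val schema = StructType(Seq(StructField("d", DateType)))

// Build the converter once and reuse it. With useJava8DateTimeApi = true the date
// column converts to java.time.LocalDate; with false it converts to java.sql.Date.
// Omitting the argument falls back to SQLConf.get.datetimeJava8ApiEnabled, which
// preserves the previous behaviour for existing callers.
val toScala = CatalystTypeConverters.createToScalaConverter(schema, useJava8DateTimeApi = true)
// internalRows.map(toScala)  // hypothetical Seq[InternalRow] sharing `schema`
```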
catalyst/dsl/package.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst

import java.sql.{Date, Timestamp}
import java.time.LocalDate

import scala.language.implicitConversions

@@ -146,6 +147,7 @@ package object dsl {
implicit def doubleToLiteral(d: Double): Literal = Literal(d)
implicit def stringToLiteral(s: String): Literal = Literal.create(s, StringType)
implicit def dateToLiteral(d: Date): Literal = Literal(d)
implicit def localDateToLiteral(d: LocalDate): Literal = Literal(d)
implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d.underlying())
implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d)
implicit def decimalToLiteral(d: Decimal): Literal = Literal(d)
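The added implicit mirrors the existing dateToLiteral conversion, so DSL-based tests can build date literals from java.time.LocalDate directly. A small hedged example, not taken from the PR:

```scala
import java.time.LocalDate

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal

// localDateToLiteral converts the LocalDate wherever a Literal is expected,
// just as dateToLiteral already does for java.sql.Date.
val dateLit: Literal = LocalDate.of(2018, 3, 18)
```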
DataSourceStrategy.scala
@@ -450,45 +450,45 @@ object DataSourceStrategy {

private def translateLeafNodeFilter(predicate: Expression): Option[Filter] = predicate match {
case expressions.EqualTo(PushableColumn(name), Literal(v, t)) =>
Some(sources.EqualTo(name, convertToScala(v, t)))
Some(sources.EqualTo(name, convertToScala(v, t, false)))

Review comment (Member):
I guess we're treating this as a temp fix for Spark 3.0? Ideally we should support Java 8 datetime instances for this interface as well when spark.sql.datetime.java8API.enabled is enabled; otherwise it could cause more confusion. That said, spark.sql.datetime.java8API.enabled is disabled by default, too.

Review comment (Contributor):
I think it's problematic to let the java8 config also control the value type inside Filter, as it can break existing DS v1 implementations. It's a bit unfortunate that we don't document clearly what the value type can be for Filter, but even if we did, it would not be user-friendly to say "the value type depends on xxx config". That would just make it harder to implement data source filter pushdown.

Review comment (Member Author):
@HyukjinKwon Taking into account #23811 (comment), the flag won't be enabled by default in the near future.

case expressions.EqualTo(Literal(v, t), PushableColumn(name)) =>
Some(sources.EqualTo(name, convertToScala(v, t)))
Some(sources.EqualTo(name, convertToScala(v, t, false)))

case expressions.EqualNullSafe(PushableColumn(name), Literal(v, t)) =>
Some(sources.EqualNullSafe(name, convertToScala(v, t)))
Some(sources.EqualNullSafe(name, convertToScala(v, t, false)))
case expressions.EqualNullSafe(Literal(v, t), PushableColumn(name)) =>
Some(sources.EqualNullSafe(name, convertToScala(v, t)))
Some(sources.EqualNullSafe(name, convertToScala(v, t, false)))

case expressions.GreaterThan(PushableColumn(name), Literal(v, t)) =>
Some(sources.GreaterThan(name, convertToScala(v, t)))
Some(sources.GreaterThan(name, convertToScala(v, t, false)))
case expressions.GreaterThan(Literal(v, t), PushableColumn(name)) =>
Some(sources.LessThan(name, convertToScala(v, t)))
Some(sources.LessThan(name, convertToScala(v, t, false)))

case expressions.LessThan(PushableColumn(name), Literal(v, t)) =>
Some(sources.LessThan(name, convertToScala(v, t)))
Some(sources.LessThan(name, convertToScala(v, t, false)))
case expressions.LessThan(Literal(v, t), PushableColumn(name)) =>
Some(sources.GreaterThan(name, convertToScala(v, t)))
Some(sources.GreaterThan(name, convertToScala(v, t, false)))

case expressions.GreaterThanOrEqual(PushableColumn(name), Literal(v, t)) =>
Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
Some(sources.GreaterThanOrEqual(name, convertToScala(v, t, false)))
case expressions.GreaterThanOrEqual(Literal(v, t), PushableColumn(name)) =>
Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
Some(sources.LessThanOrEqual(name, convertToScala(v, t, false)))

case expressions.LessThanOrEqual(PushableColumn(name), Literal(v, t)) =>
Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
Some(sources.LessThanOrEqual(name, convertToScala(v, t, false)))
case expressions.LessThanOrEqual(Literal(v, t), PushableColumn(name)) =>
Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
Some(sources.GreaterThanOrEqual(name, convertToScala(v, t, false)))

case expressions.InSet(e @ PushableColumn(name), set) =>
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType, false)
Some(sources.In(name, set.toArray.map(toScala)))

// Because we only convert In to InSet in Optimizer when there are more than certain
// items. So it is possible we still get an In expression here that needs to be pushed
// down.
case expressions.In(e @ PushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) =>
val hSet = list.map(_.eval(EmptyRow))
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType, false)
Some(sources.In(name, hSet.toArray.map(toScala)))

case expressions.IsNull(PushableColumn(name)) =>
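As the review thread above notes, the hard-coded false keeps the value type inside pushed-down Filters stable: DS v1 sources keep receiving java.sql.Date / java.sql.Timestamp even when spark.sql.datetime.java8API.enabled is on for the session. A hedged sketch of the kind of existing implementation this protects; the function and its logic are hypothetical user code, not part of the PR:

```scala
import java.sql.Date

import org.apache.spark.sql.sources.{Filter, GreaterThan}

// A DS v1 source that pattern-matches on java.sql.Date in pushed filters keeps
// working, because translateLeafNodeFilter always converts filter literals with
// useJava8DateTimeApi = false.
def keepsRow(filter: Filter, rowValue: Date): Boolean = filter match {
  case GreaterThan(_, bound: Date) => rowValue.after(bound)
  case _ => true // conservatively keep rows for filters this source does not handle
}
```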
ParquetFilterSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.math.{BigDecimal => JBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.LocalDate

import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators}
import org.apache.parquet.filter2.predicate.FilterApi._
@@ -525,52 +526,62 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
def date: Date = Date.valueOf(s)
}

val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21").map(_.date)
val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21")
import testImplicits._
withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) =>
withParquetDataFrame(inputDF) { implicit df =>
val dateAttr: Expression = df(colName).expr
assert(df(colName).expr.dataType === DateType)

checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
data.map(i => Row.apply(resultFun(i))))

checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]],
resultFun("2018-03-18".date))
checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]],
resultFun("2018-03-18".date))
checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]],
Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(resultFun(i.date))))

checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]],
resultFun("2018-03-18".date))
checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]],
resultFun("2018-03-21".date))
checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]],
resultFun("2018-03-18".date))
checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]],
resultFun("2018-03-21".date))

checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]],
resultFun("2018-03-18".date))
checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]],
resultFun("2018-03-18".date))
checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]],
resultFun("2018-03-18".date))
checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]],
resultFun("2018-03-21".date))
checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]],
resultFun("2018-03-18".date))
checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]],
resultFun("2018-03-21".date))

checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]],
resultFun("2018-03-21".date))
checkFilterPredicate(
dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date,
classOf[Operators.Or],
Seq(Row(resultFun("2018-03-18".date)), Row(resultFun("2018-03-21".date))))
Seq(false, true).foreach { java8Api =>
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
val df = data.map(i => Tuple1(Date.valueOf(i))).toDF()
withNestedDataFrame(df) { case (inputDF, colName, fun) =>
def resultFun(dateStr: String): Any = {
val parsed = if (java8Api) LocalDate.parse(dateStr) else Date.valueOf(dateStr)
fun(parsed)
}
withParquetDataFrame(inputDF) { implicit df =>
val dateAttr: Expression = df(colName).expr
assert(df(colName).expr.dataType === DateType)

checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
data.map(i => Row.apply(resultFun(i))))

checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]],
resultFun("2018-03-18"))
checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]],
resultFun("2018-03-18"))
checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]],
Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(resultFun(i))))

checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]],
resultFun("2018-03-18"))
checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]],
resultFun("2018-03-21"))
checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]],
resultFun("2018-03-18"))
checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]],
resultFun("2018-03-21"))

checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]],
resultFun("2018-03-18"))
checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]],
resultFun("2018-03-18"))
checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]],
resultFun("2018-03-18"))
checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]],
resultFun("2018-03-21"))
checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]],
resultFun("2018-03-18"))
checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]],
resultFun("2018-03-21"))

checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]],
resultFun("2018-03-21"))
checkFilterPredicate(
dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date,
classOf[Operators.Or],
Seq(Row(resultFun("2018-03-18")), Row(resultFun("2018-03-21"))))
}
}
}
}
}
OrcFilterSuite.scala
@@ -299,26 +299,33 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
}

test("filter pushdown - date") {
val dates = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day =>
val input = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day =>
Date.valueOf(day)
}
withOrcDataFrame(dates.map(Tuple1(_))) { implicit df =>
checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL)

checkFilterPredicate($"_1" === dates(0), PredicateLeaf.Operator.EQUALS)
checkFilterPredicate($"_1" <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)

checkFilterPredicate($"_1" < dates(1), PredicateLeaf.Operator.LESS_THAN)
checkFilterPredicate($"_1" > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
checkFilterPredicate($"_1" <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
checkFilterPredicate($"_1" >= dates(3), PredicateLeaf.Operator.LESS_THAN)

checkFilterPredicate(Literal(dates(0)) === $"_1", PredicateLeaf.Operator.EQUALS)
checkFilterPredicate(Literal(dates(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
checkFilterPredicate(Literal(dates(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN)
checkFilterPredicate(Literal(dates(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
checkFilterPredicate(Literal(dates(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
checkFilterPredicate(Literal(dates(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN)
withOrcFile(input.map(Tuple1(_))) { path =>
Seq(false, true).foreach { java8Api =>
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
readFile(path) { implicit df =>
val dates = input.map(Literal(_))
checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL)

checkFilterPredicate($"_1" === dates(0), PredicateLeaf.Operator.EQUALS)
checkFilterPredicate($"_1" <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)

checkFilterPredicate($"_1" < dates(1), PredicateLeaf.Operator.LESS_THAN)
checkFilterPredicate($"_1" > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
checkFilterPredicate($"_1" <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
checkFilterPredicate($"_1" >= dates(3), PredicateLeaf.Operator.LESS_THAN)

checkFilterPredicate(dates(0) === $"_1", PredicateLeaf.Operator.EQUALS)
checkFilterPredicate(dates(0) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
checkFilterPredicate(dates(1) > $"_1", PredicateLeaf.Operator.LESS_THAN)
checkFilterPredicate(dates(2) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
checkFilterPredicate(dates(0) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
checkFilterPredicate(dates(3) <= $"_1", PredicateLeaf.Operator.LESS_THAN)
}
}
}
}
}
