diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 78040d99fb0a5..43b00e4b08ee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -27,20 +27,55 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.sql.sources import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { - case class SetInFilter[T <: Comparable[T]]( - valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { + case class InSetFilter[T <: Comparable[T]](valueSet: Set[T]) + extends UserDefinedPredicate[T] { + + private val min = valueSet.min + private val max = valueSet.max override def keep(value: T): Boolean = { value != null && valueSet.contains(value) } - override def canDrop(statistics: Statistics[T]): Boolean = false + override def canDrop(statistics: Statistics[T]): Boolean = { + statistics.getMax.compareTo(min) < 0 || statistics.getMin.compareTo(max) > 0 + } override def inverseCanDrop(statistics: Statistics[T]): Boolean = false } + abstract class StringFilter extends UserDefinedPredicate[Binary] { + override def canDrop(statistics: Statistics[Binary]): Boolean = false + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false + + def binaryToUTF8String(value: Binary): UTF8String = { + // This is a trick used in CatalystStringConverter to steal the underlying + // byte array of the binary without copying it. + val buffer = value.toByteBuffer + val offset = buffer.position() + val numBytes = buffer.limit() - buffer.position() + UTF8String.fromBytes(buffer.array(), offset, numBytes) + } + } + + case class StringStartsWithFilter(prefix: String) extends StringFilter { + private val strToCompare: UTF8String = UTF8String.fromString(prefix) + override def keep(value: Binary): Boolean = binaryToUTF8String(value).startsWith(strToCompare) + } + + case class StringEndsWithFilter(suffix: String) extends StringFilter { + private val strToCompare: UTF8String = UTF8String.fromString(suffix) + override def keep(value: Binary): Boolean = binaryToUTF8String(value).endsWith(strToCompare) + } + + case class StringContainsFilter(str: String) extends StringFilter { + private val strToCompare: UTF8String = UTF8String.fromString(str) + override def keep(value: Binary): Boolean = binaryToUTF8String(value).contains(strToCompare) + } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) @@ -157,27 +192,54 @@ private[sql] object ParquetFilters { FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) } + private val makeStringStartsFilter: PartialFunction[DataType, + (String, String) => FilterPredicate] = { + case StringType => + (n: String, v: String) => + FilterApi.userDefined(binaryColumn(n), + StringStartsWithFilter(v.asInstanceOf[java.lang.String])) + } + + private val makeStringEndsFilter: PartialFunction[DataType, + (String, String) => FilterPredicate] = { + case StringType => + (n: String, v: String) => + FilterApi.userDefined(binaryColumn(n), + StringEndsWithFilter(v.asInstanceOf[java.lang.String])) + } + + private val makeStringContainsFilter: PartialFunction[DataType, + (String, String) => FilterPredicate] = { + case StringType => + (n: String, v: String) => + FilterApi.userDefined(binaryColumn(n), + StringContainsFilter(v.asInstanceOf[java.lang.String])) + } + private val makeInSet: PartialFunction[DataType, (String, Set[Any]) => FilterPredicate] = { + case BooleanType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(booleanColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Boolean]])) case IntegerType => (n: String, v: Set[Any]) => - FilterApi.userDefined(intColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Integer]])) + FilterApi.userDefined(intColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Integer]])) case LongType => (n: String, v: Set[Any]) => - FilterApi.userDefined(longColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Long]])) + FilterApi.userDefined(longColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Long]])) case FloatType => (n: String, v: Set[Any]) => - FilterApi.userDefined(floatColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Float]])) + FilterApi.userDefined(floatColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Float]])) case DoubleType => (n: String, v: Set[Any]) => - FilterApi.userDefined(doubleColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Double]])) + FilterApi.userDefined(doubleColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Double]])) case StringType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), - SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))))) + InSetFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))))) case BinaryType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), - SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[Array[Byte]])))) + InSetFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[Array[Byte]])))) } /** @@ -209,6 +271,9 @@ private[sql] object ParquetFilters { case sources.IsNotNull(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) + case sources.In(name, values) => + makeInSet.lift(dataTypeOf(name)).map(_(name, values.toSet)) + case sources.EqualTo(name, value) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) case sources.Not(sources.EqualTo(name, value)) => @@ -229,6 +294,13 @@ private[sql] object ParquetFilters { case sources.GreaterThanOrEqual(name, value) => makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.StringStartsWith(name, value) => + makeStringStartsFilter.lift(dataTypeOf(name)).map(_(name, value)) + case sources.StringEndsWith(name, value) => + makeStringEndsFilter.lift(dataTypeOf(name)).map(_(name, value)) + case sources.StringContains(name, value) => + makeStringContainsFilter.lift(dataTypeOf(name)).map(_(name, value)) + case sources.And(lhs, rhs) => (createFilter(schema, lhs) ++ createFilter(schema, rhs)).reduceOption(FilterApi.and) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 7a23f57f40392..5b9e3f08e22e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -112,6 +112,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 === true, classOf[Eq[_]], true) checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) + + checkFilterPredicate( + ('_1.in(true)).asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], true) + checkFilterPredicate( + ('_1.in(false)).asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], false) } } @@ -138,6 +143,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + + checkFilterPredicate( + ('_1.in(1, 2)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(1), Row(2))) + checkFilterPredicate( + ('_1.in(3, 4)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(3), Row(4))) } } @@ -164,6 +178,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + + checkFilterPredicate( + ('_1.in(1L, 2L)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(1L), Row(2L))) + checkFilterPredicate( + ('_1.in(3L, 4L)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(3L), Row(4L))) } } @@ -190,6 +213,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + + checkFilterPredicate( + ('_1.in(1.0f, 2.0f)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(1.0f), Row(2.0f))) + checkFilterPredicate( + ('_1.in(3.0f, 4.0f)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(3.0f), Row(4.0f))) } } @@ -216,6 +248,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + + checkFilterPredicate( + ('_1.in(1.0, 2.0)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(1.0), Row(2.0))) + checkFilterPredicate( + ('_1.in(3.0, 4.0)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(3.0), Row(4.0))) } } @@ -244,6 +285,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + + checkFilterPredicate( + ('_1.in("1", "2")).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row("1"), Row("2"))) + checkFilterPredicate( + ('_1.in("3", "4")).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row("3"), Row("4"))) + } + + withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString * 5 + "test"))) { implicit df => + checkFilterPredicate( + ('_1 contains "11").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "11111test") + + checkFilterPredicate( + ('_1 contains "2test").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "22222test") + + checkFilterPredicate( + ('_1 contains "3t").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "33333test") + + checkFilterPredicate( + ('_1 startsWith "22").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "22222test") + + checkFilterPredicate( + ('_1 endsWith "4test").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "44444test") + + checkFilterPredicate( + ('_1 endsWith "2test").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "22222test") } } @@ -278,6 +360,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate( '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) + + checkFilterPredicate( + ('_1.in(1.b, 2.b)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(1.b), Row(2.b))) + checkFilterPredicate( + ('_1.in(3.b, 4.b)).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq(Row(3.b), Row(4.b))) } }