diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index dae954a579eb..011371a513a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -181,7 +181,7 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
   }
 }
 
-abstract class MultiLikeBase
+sealed abstract class MultiLikeBase
   extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   protected def patterns: Seq[UTF8String]
@@ -220,7 +220,7 @@ abstract class MultiLikeBase
 /**
  * Optimized version of LIKE ALL, when all pattern values are literal.
  */
-abstract class LikeAllBase extends MultiLikeBase {
+sealed abstract class LikeAllBase extends MultiLikeBase {
 
   override def matches(exprValue: String): Any = {
     if (cache.forall(matchFunc(_, exprValue))) {
@@ -276,7 +276,7 @@ case class NotLikeAll(child: Expression, patterns: Seq[UTF8String]) extends Like
 /**
  * Optimized version of LIKE ANY, when all pattern values are literal.
  */
-abstract class LikeAnyBase extends MultiLikeBase {
+sealed abstract class LikeAnyBase extends MultiLikeBase {
 
   override def matches(exprValue: String): Any = {
     if (cache.exists(matchFunc(_, exprValue))) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 819bffeafb64..a40456da8297 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet
 import scala.collection.mutable.{ArrayBuffer, Stack}
 
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, _}
+import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, MultiLikeBase, _}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /*
  * Optimization rules defined in this file should not affect the structure of the logical plan.
@@ -634,36 +635,83 @@ object LikeSimplification extends Rule[LogicalPlan] {
   private val contains = "%([^_%]+)%".r
   private val equalTo = "([^_%]*)".r
 
+  /**
+   * Rewrites a single LIKE pattern into a cheaper string predicate on `input` when the
+   * pattern has one of the simple shapes matched by the regexes of this object.
+   *
+   * @return the replacement expression, or None when the pattern contains `escapeChar` or
+   *         does not have one of the supported shapes.
+   */
+  private def simplifyLike(
+      input: Expression, pattern: String, escapeChar: Char = '\\'): Option[Expression] = {
+    if (pattern.contains(escapeChar)) {
+      // There are three different situations when the pattern contains escapeChar:
+      // 1. pattern contains invalid escape sequence, e.g. 'm\aca'
+      // 2. pattern contains escaped wildcard character, e.g. 'ma\%ca'
+      // 3. pattern contains escaped escape character, e.g. 'ma\\ca'
+      // Although some such patterns could be optimized by handling the escape first, we just
+      // skip this rule if pattern contains any escapeChar for simplicity.
+      None
+    } else {
+      pattern match {
+        case startsWith(prefix) =>
+          Some(StartsWith(input, Literal(prefix)))
+        case endsWith(postfix) =>
+          Some(EndsWith(input, Literal(postfix)))
+        // 'a%a' pattern is basically same with 'a%' && '%a'.
+        // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
+        case startsAndEndsWith(prefix, postfix) =>
+          Some(And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)),
+            And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))))
+        case contains(infix) =>
+          Some(Contains(input, Literal(infix)))
+        case equalTo(str) =>
+          Some(EqualTo(input, Literal(str)))
+        case _ => None
+      }
+    }
+  }
+
+  /**
+   * Splits the patterns of a multi-LIKE expression into those that [[simplifyLike]] can
+   * rewrite and those it cannot. The rewritten predicates are combined with And (LIKE ALL
+   * variants) or Or (LIKE ANY variants), negated for the NOT variants, and joined with a
+   * copy of `multi` that keeps only the remaining patterns. Returns `multi` unchanged when
+   * nothing can be rewritten.
+   */
+  private def simplifyMultiLike(
+      child: Expression, patterns: Seq[UTF8String], multi: MultiLikeBase): Expression = {
+    // Guard against null patterns: they are never dereferenced and stay in `remainPatterns`.
+    val simplified =
+      patterns.map(p => p -> Option(p).flatMap(s => simplifyLike(child, s.toString)))
+    val remainPatterns = simplified.collect { case (p, None) => p }
+    val replacements = simplified.collect { case (_, Some(replacement)) => replacement }
+    if (replacements.isEmpty) {
+      multi
+    } else {
+      multi match {
+        case l: LikeAll => And(replacements.reduceLeft(And), l.copy(patterns = remainPatterns))
+        case l: NotLikeAll =>
+          And(replacements.map(Not(_)).reduceLeft(And), l.copy(patterns = remainPatterns))
+        case l: LikeAny => Or(replacements.reduceLeft(Or), l.copy(patterns = remainPatterns))
+        case l: NotLikeAny =>
+          Or(replacements.map(Not(_)).reduceLeft(Or), l.copy(patterns = remainPatterns))
+      }
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case l @ Like(input, Literal(pattern, StringType), escapeChar) =>
       if (pattern == null) {
         // If pattern is null, return null value directly, since "col like null" == null.
         Literal(null, BooleanType)
       } else {
-        pattern.toString match {
-          // There are three different situations when pattern containing escapeChar:
-          // 1. pattern contains invalid escape sequence, e.g. 'm\aca'
-          // 2. pattern contains escaped wildcard character, e.g. 'ma\%ca'
-          // 3. pattern contains escaped escape character, e.g. 'ma\\ca'
-          // Although there are patterns can be optimized if we handle the escape first, we just
-          // skip this rule if pattern contains any escapeChar for simplicity.
-          case p if p.contains(escapeChar) => l
-          case startsWith(prefix) =>
-            StartsWith(input, Literal(prefix))
-          case endsWith(postfix) =>
-            EndsWith(input, Literal(postfix))
-          // 'a%a' pattern is basically same with 'a%' && '%a'.
-          // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
-          case startsAndEndsWith(prefix, postfix) =>
-            And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)),
-              And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
-          case contains(infix) =>
-            Contains(input, Literal(infix))
-          case equalTo(str) =>
-            EqualTo(input, Literal(str))
-          case _ => l
-        }
+        simplifyLike(input, pattern.toString, escapeChar).getOrElse(l)
       }
+    case l @ LikeAll(child, patterns) => simplifyMultiLike(child, patterns, l)
+    case l @ NotLikeAll(child, patterns) => simplifyMultiLike(child, patterns, l)
+    case l @ LikeAny(child, patterns) => simplifyMultiLike(child, patterns, l)
+    case l @ NotLikeAny(child, patterns) => simplifyMultiLike(child, patterns, l)
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
index 1812dce0da42..c06c92f9c151 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
@@ -164,4 +164,72 @@ class LikeSimplificationSuite extends PlanTest {
       .analyze
     comparePlans(optimized5, correctAnswer5)
   }
+
+  test("simplify LikeAll") {
+    val originalQuery =
+      testRelation
+        .where(('a likeAll(
+          "abc%", "abc\\%", "%xyz", "abc\\%def", "abc%def", "%mn%", "%mn\\%", "", "abc")))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .where((((((StartsWith('a, "abc") && EndsWith('a, "xyz")) &&
+        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) &&
+        Contains('a, "mn")) && ('a === "")) && ('a === "abc")) &&
+        ('a likeAll("abc\\%", "abc\\%def", "%mn\\%")))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("simplify NotLikeAll") {
+    val originalQuery =
+      testRelation
+        .where(('a notLikeAll(
+          "abc%", "abc\\%", "%xyz", "abc\\%def", "abc%def", "%mn%", "%mn\\%", "", "abc")))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .where((((((Not(StartsWith('a, "abc")) && Not(EndsWith('a, "xyz"))) &&
+        Not(Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) &&
+        Not(Contains('a, "mn"))) && Not('a === "")) && Not('a === "abc")) &&
+        ('a notLikeAll("abc\\%", "abc\\%def", "%mn\\%")))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("simplify LikeAny") {
+    val originalQuery =
+      testRelation
+        .where(('a likeAny(
+          "abc%", "abc\\%", "%xyz", "abc\\%def", "abc%def", "%mn%", "%mn\\%", "", "abc")))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .where((((((StartsWith('a, "abc") || EndsWith('a, "xyz")) ||
+        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) ||
+        Contains('a, "mn")) || ('a === "")) || ('a === "abc")) ||
+        ('a likeAny("abc\\%", "abc\\%def", "%mn\\%")))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("simplify NotLikeAny") {
+    val originalQuery =
+      testRelation
+        .where(('a notLikeAny(
+          "abc%", "abc\\%", "%xyz", "abc\\%def", "abc%def", "%mn%", "%mn\\%", "", "abc")))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .where((((((Not(StartsWith('a, "abc")) || Not(EndsWith('a, "xyz"))) ||
+        Not(Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) ||
+        Not(Contains('a, "mn"))) || Not('a === "")) || Not('a === "abc")) ||
+        ('a notLikeAny("abc\\%", "abc\\%def", "%mn\\%")))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }