diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9510aa4d9e707..dc64e5e256052 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -618,6 +618,134 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre } } + /** + * Returns alternative copies of this node where `rule` has been recursively applied to it and all + * of its children (pre-order). + * + * @param rule a function used to generate alternatives for a node + * @return the stream of alternatives + */ + def multiTransformDown( + rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = { + multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Returns alternative copies of this node where `rule` has been recursively applied to it and all + * of its children (pre-order). + * + * As it is very easy to generate enormous number of alternatives when the input tree is huge or + * when the rule returns many alternatives for many nodes, this function returns the alternatives + * as a lazy `Stream` to be able to limit the number of alternatives generated at the caller side + * as needed. + * + * The rule should not apply or can return a one element stream of original node to indicate that + * the original node without any transformation is a valid alternative. + * + * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this + * case `multiTransformDown()` returns an empty `Stream`. + * + * Please consider the following examples of `input.multiTransformDown(rule)`: + * + * We have an input expression: + * `Add(a, b)` + * + * 1. 
+ * We have a simple rule: + * `a` => `Stream(1, 2)` + * `b` => `Stream(10, 20)` + * `Add(a, b)` => `Stream(11, 12, 21, 22)` + * + * The output is: + * `Stream(11, 12, 21, 22)` + * + * 2. + * In the previous example if we want to generate alternatives of `a` and `b` too then we need to + * explicitly add the original `Add(a, b)` expression to the rule: + * `a` => `Stream(1, 2)` + * `b` => `Stream(10, 20)` + * `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))` + * + * The output is: + * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))` + * + * @param rule a function used to generate alternatives for a node + * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false + * on a TreeNode T, skips processing T and its subtree; otherwise, processes + * T and its subtree recursively. + * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is + * UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`) + * has been marked as ineffective on a TreeNode T, skips processing T and its + * subtree. Do not pass it if the rule is not purely functional and reads a + * varying initial state for different invocations. + * @return the stream of alternatives + */ + def multiTransformDownWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId + )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = { + if (!cond.apply(this) || isRuleIneffective(ruleId)) { + return Stream(this) + } + + // We could return `Stream(this)` if the `rule` doesn't apply and handle both + // - the rule doesn't apply + // - and the rule returns a one element `Stream(originalNode)` + // cases together. But, unfortunately it doesn't seem like there is a way to match on a one + // element stream without eagerly computing the tail head. So this contradicts with the purpose + // of only taking the necessary elements from the alternatives. I.e. 
the + // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail. + // Please note that this behaviour has a downside as well that we can only mark the rule on the + // original node ineffective if the rule didn't match. + var ruleApplied = true + val afterRules = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, (_: BaseType) => { + ruleApplied = false + Stream.empty + }) + } + + val afterRulesStream = if (afterRules.isEmpty) { + if (ruleApplied) { + // If the rule returned with empty alternatives then prune + Stream.empty + } else { + // No match: keep node. NOTE(review): marked ineffective before children run; confirm + this.markRuleAsIneffective(ruleId) + Stream(this) + } + } else { + // If the rule was applied then use the returned alternatives + afterRules.map { afterRule => + if (this fastEquals afterRule) { + this + } else { + afterRule.copyTagsFrom(this) + afterRule + } + } + } + + afterRulesStream.flatMap { afterRule => + if (afterRule.containsChild.nonEmpty) { + generateChildrenSeq( + afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule))) + .map(afterRule.withNewChildren) + } else { + Stream(afterRule) + } + } + } + + private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = { + childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, childrenSeqStream) => + for { + childrenSeq <- childrenSeqStream + child <- childrenStream + } yield child +: childrenSeq + ) + } + /** * Returns a copy of this node where `f` has been applied to all the nodes in `children`. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 286d3dddae6ea..ac28917675e6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -977,4 +977,110 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(origin.context.summary.isEmpty) } } + + // Returns `es` followed by a tail that throws when forced, so over-evaluation is detected + private def newErrorAfterStream(es: Expression*) = { + es.toStream.append( + throw new NoSuchElementException("Stream should not return more elements") + ) + } + + test("multiTransformDown generates all alternatives") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => + Stream(Literal(100), Literal(200), Literal(300)) + } + val expected = for { + cd <- Seq(Literal(100), Literal(200), Literal(300)) + b <- Seq(Literal(10), Literal(20), Literal(30)) + a <- Seq(Literal(1), Literal(2), Literal(3)) + } yield Add(Add(a, b), cd) + assert(transformed === expected) + } + + test("multiTransformDown is lazy") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => newErrorAfterStream(Literal(10)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100)) + } + val expected = for { + a <- Seq(Literal(1), Literal(2), Literal(3)) + } yield Add(Add(a, Literal(10)), Literal(100)) + // We don't access alternatives for `b` after 10 and for `Add(c, d)` after 100 + assert(transformed.take(3) == 
expected) + intercept[NoSuchElementException] { + transformed.take(3 + 1).toList + } + + val transformed2 = e.multiTransformDown { + case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100)) + } + val expected2 = for { + b <- Seq(Literal(10), Literal(20), Literal(30)) + a <- Seq(Literal(1), Literal(2), Literal(3)) + } yield Add(Add(a, b), Literal(100)) + // We don't access alternatives for `Add(c, d)` after 100 + assert(transformed2.take(3 * 3) === expected2) + // A 10th element needs a second `Add(c, d)` alternative, so forcing it must throw + intercept[NoSuchElementException] { + transformed2.take(3 * 3 + 1).toList + } + } + + test("multiTransformDown rule return this") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s) + case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s) + case a @ Add(StringLiteral("c"), StringLiteral("d"), _) => + Stream(Literal(100), Literal(200), a) + } + val expected = for { + cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d"))) + b <- Seq(Literal(10), Literal(20), Literal("b")) + a <- Seq(Literal(1), Literal(2), Literal("a")) + } yield Add(Add(a, b), cd) + assert(transformed == expected) + } + + test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " + + "transformed and itself is in the alternatives") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case a @ Add(StringLiteral("a"), StringLiteral("b"), _) => + Stream(Literal(11), Literal(12), Literal(21), Literal(22), a) + case StringLiteral("a") => Stream(Literal(1), Literal(2)) + case StringLiteral("b") => Stream(Literal(10), Literal(20)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), 
Literal(200)) + } + val expected = for { + cd <- Seq(Literal(100), Literal(200)) + ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++ + (for { + b <- Seq(Literal(10), Literal(20)) + a <- Seq(Literal(1), Literal(2)) + } yield Add(a, b)) + } yield Add(ab, cd) + assert(transformed == expected) + } + + test("multiTransformDown can prune") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case StringLiteral("a") => Stream.empty + } + assert(transformed.isEmpty) + + val transformed2 = e.multiTransformDown { + case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty + } + assert(transformed2.isEmpty) + } }