From 52546d0353bb3afc47a3b707809468be027129c3 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 28 Sep 2022 10:37:17 +0200 Subject: [PATCH 1/5] [SPARK-40599][SQL] Add multiTransform methods to TreeNode to generate alternatives --- .../spark/sql/catalyst/trees/TreeNode.scala | 159 ++++++++++++++++++ .../sql/catalyst/trees/TreeNodeSuite.scala | 111 ++++++++++++ 2 files changed, 270 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9510aa4d9e707..1495172a13503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -618,6 +618,165 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre } } + /** + * Returns alternative copies of this node where `rule` has been recursively applied to the tree. + * + * Users should not expect a specific directionality. If a specific directionality is needed, + * multiTransformDown or multiTransformUp should be used. + * + * @param rule a function used to generate transformed alternatives for a node + * @return the stream of alternatives + */ + def multiTransform(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = { + multiTransformDown(rule) + } + + /** + * Returns alternative copies of this node where `rule` has been recursively applied to the tree. + * + * Users should not expect a specific directionality. If a specific directionality is needed, + * multiTransformDownWithPruning or multiTransformUpWithPruning should be used. + * + * @param rule a function used to generate transformed alternatives for a node + * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false + * on a TreeNode T, skips processing T and its subtree; otherwise, processes + * T and its subtree recursively. + * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is + * UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`) + * has been marked as in effective on a TreeNode T, skips processing T and its + * subtree. Do not pass it if the rule is not purely functional and reads a + * varying initial state for different invocations. + * @return the stream of alternatives + */ + def multiTransformWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId + )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = { + multiTransformDownWithPruning(cond, ruleId)(rule).map(_._1) + } + + /** + * Returns alternative copies of this node where `rule` has been recursively applied to it and all + * of its children (pre-order). + * + * @param rule the function used to generate transformed alternatives for a node + * @return the stream of alternatives + */ + def multiTransformDown(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = { + multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule).map(_._1) + } + + /** + * Returns alternative copies of this node where `rule` has been recursively applied to it and all + * of its children (pre-order). + * + * As it is very easy to generate enormous number of alternatives when the input tree is huge or + * when the rule returns large number of alternatives, this function returns the alternatives as a + * lazy `Stream` to be able to limit the number of alternatives generated at the caller side as + * needed. + * + * To indicate that the original node without any transformation is a valid alternative the rule + * can either: + * - not apply or + * - return an empty `Seq` or + * - a `Seq` that contains a node that is equal to the original node. + * + * Please note that this function always consider the original node as a valid alternative (even + * if the original node is not included in the returned `Seq`) if the rule can transform any of + * the descendants of the node. E.g. consider a simple expression: + * `Add(a, b)` + * and a rule that returns: + * `Seq(1, 2)` for `a` and + * `Seq(10, 20)` for `b` and + * `Seq(11, 12, 21, 22)` for `Add(a, b)` (note that the original `Add(a, b)` is not returned) + * then the result of `multiTransform` is: + * `Seq(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`. + * This feature makes the usage of `multiTransform` easier as a non-leaf transforming rule doesn't + * need to take into account that it can transform a descendant node of the non-leaf node as well + * and so it doesn't need return the non-leaf node itself in the list of alternatives to not stop + * generating alternatives. + * + * @param rule a function used to generate transformed alternatives for a node + * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false + * on a TreeNode T, skips processing T and its subtree; otherwise, processes + * T and its subtree recursively. + * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is + * UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`) + * has been marked as in effective on a TreeNode T, skips processing T and its + * subtree. Do not pass it if the rule is not purely functional and reads a + * varying initial state for different invocations. + * @return the stream of alternatives with a flag if any transformation was done + */ + def multiTransformDownWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId + )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[(BaseType, Boolean)] = { + if (!cond.apply(this) || isRuleIneffective(ruleId)) { + return Stream(this -> false) + } + + val afterRules = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, (_: BaseType) => Seq.empty) + } + // A stream of a tuple that contains: + // - a node that is either the transformed alternative of the current node or the current node, + // - a boolean flag if the node was actually transformed, + // - a boolean flag if a node's children needs to be transformed to add the node to the valid + // alternatives + val afterRulesStream = if (afterRules.isEmpty) { + // If the rule is not applied or returns with empty alternatives keep the original node + Stream((this, false, false)) + } else { + // If the rule is applied then use the returned alternatives. The alternatives can include the + // current node and we need to keep track of that + var foundEqual = false + afterRules.toStream.map { afterRule => + (if (this fastEquals afterRule) { + foundEqual = true + this + } else { + afterRule.copyTagsFrom(this) + afterRule + }, true, false) + }.append( + // If the current node is not a leaf node and the alternatives returned by the rule doesn't + // contain it then we need to add the current node to the stream, but require any of its + // child nodes to be transformed to keep it as a valid alternative + if (containsChild.nonEmpty && !foundEqual) { + Stream((this, false, true)) + } else { + Stream.empty + } + ) + } + + def generateChildrenSeq(children: Seq[BaseType]): Stream[(Seq[BaseType], Boolean)] = { + children.foldRight(Stream((Seq.empty[BaseType], false)))((child, childrenSeqStream) => + for { + (childrenSeq, childrenSeqChanged) <- childrenSeqStream + (newChild, childChanged) <- child.multiTransformDownWithPruning(cond, ruleId)(rule) + } yield (newChild +: childrenSeq) -> (childChanged || childrenSeqChanged) + ) + } + + afterRulesStream.flatMap { case (afterRule, transformed, childrenTransformRequired) => + if (afterRule.containsChild.nonEmpty) { + generateChildrenSeq(afterRule.children).collect { + case (newChildren, childrenTransformed) + if !childrenTransformRequired || childrenTransformed => + afterRule.withNewChildren(newChildren) -> (transformed || childrenTransformed) + } + } else { + Seq(afterRule -> transformed) + }.map { rewritten_plan => + if (this eq rewritten_plan) { + markRuleAsIneffective(ruleId) + } + rewritten_plan + } + } + } + /** * Returns a copy of this node where `f` has been applied to all the nodes in `children`. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 286d3dddae6ea..f2cec5064fca3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -977,4 +977,115 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(origin.context.summary.isEmpty) } } + + private def newErrorAfterStream(es: Expression*) = { + es.toStream.append( + throw new NoSuchElementException("Stream should not return more elements") + ) + } + + test("multiTransformDown generates all alternatives") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => + Seq(Literal(100), Literal(200), Literal(300)) + } + val expected = for { + cd <- Seq(Literal(100), Literal(200), Literal(300)) + b <- Seq(Literal(10), Literal(20), Literal(30)) + a <- Seq(Literal(1), Literal(2), Literal(3)) + } yield Add(Add(a, b), cd) + assert(transformed === expected) + } + + test("multiTransformDown is lazy") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => newErrorAfterStream(Literal(10)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100)) + } + val expected = for { + a <- Seq(Literal(1), Literal(2), Literal(3)) + } yield Add(Add(a, Literal(10)), Literal(100)) + // We don't access alternatives for `b` after 10 and for `c` after 100 + assert(transformed.take(3) == expected) + intercept[NoSuchElementException] { + transformed.take(3 + 1).toList + } + + val transformed2 = e.multiTransformDown { + case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100)) + } + val expected2 = for { + b <- Seq(Literal(10), Literal(20), Literal(30)) + a <- Seq(Literal(1), Literal(2), Literal(3)) + } yield Add(Add(a, b), Literal(100)) + // We don't access alternatives for `c` after 100 + assert(transformed2.take(3 * 3) === expected2) + intercept[NoSuchElementException] { + transformed.take(3 * 3 + 1).toList + } + } + + test("multiTransformDown rule return this") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case s @ StringLiteral("a") => Seq(Literal(1), Literal(2), s) + case s @ StringLiteral("b") => Seq(Literal(10), Literal(20), s) + case a @ Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200), a) + } + val expected = for { + cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d"))) + b <- Seq(Literal(10), Literal(20), Literal("b")) + a <- Seq(Literal(1), Literal(2), Literal("a")) + } yield Add(Add(a, b), cd) + assert(transformed == expected) + } + + test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " + + "transformed") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case Add(StringLiteral("a"), StringLiteral("b"), _) => + Seq(Literal(11), Literal(12), Literal(21), Literal(22)) + case StringLiteral("a") => Seq(Literal(1), Literal(2)) + case StringLiteral("b") => Seq(Literal(10), Literal(20)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200)) + } + val expected = for { + cd <- Seq(Literal(100), Literal(200)) + ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++ + (for { + b <- Seq(Literal(10), Literal(20)) + a <- Seq(Literal(1), Literal(2)) + } yield Add(a, b)) + } yield Add(ab, cd) + assert(transformed == expected) + } + + test("multiTransformDown non-leaf transformation if a descendant can be transformed too " + + "behaves like non-leaf returned itself") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case a @ Add(StringLiteral("a"), StringLiteral("b"), _) => + Seq(Literal(11), Literal(12), Literal(21), Literal(22), a) + case StringLiteral("a") => Seq(Literal(1), Literal(2)) + case StringLiteral("b") => Seq(Literal(10), Literal(20)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200)) + } + val expected = for { + cd <- Seq(Literal(100), Literal(200)) + ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++ + (for { + b <- Seq(Literal(10), Literal(20)) + a <- Seq(Literal(1), Literal(2)) + } yield Add(a, b)) + } yield Add(ab, cd) + assert(transformed == expected) + } } From b99f2f9a7c1515021294c25bbb727a120f28feed Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 14 Jan 2023 13:52:56 +0100 Subject: [PATCH 2/5] Add pruning option to multiTransform --- .../spark/sql/catalyst/trees/TreeNode.scala | 61 +++++++++++-------- .../sql/catalyst/trees/TreeNodeSuite.scala | 13 ++++ 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 1495172a13503..451c8d7a14345 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -678,9 +678,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre * To indicate that the original node without any transformation is a valid alternative the rule * can either: * - not apply or - * - return an empty `Seq` or * - a `Seq` that contains a node that is equal to the original node. * + * The rule can return `Seq.empty` to indicate that the original node should be pruned from the + * alternatives. + * * Please note that this function always consider the original node as a valid alternative (even * if the original node is not included in the returned `Seq`) if the rule can transform any of * the descendants of the node. E.g. consider a simple expression: @@ -716,38 +718,43 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre } val afterRules = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, (_: BaseType) => Seq.empty) + rule.applyOrElse(this, (t: BaseType) => Seq(t)) } // A stream of a tuple that contains: // - a node that is either the transformed alternative of the current node or the current node, // - a boolean flag if the node was actually transformed, // - a boolean flag if a node's children needs to be transformed to add the node to the valid // alternatives - val afterRulesStream = if (afterRules.isEmpty) { - // If the rule is not applied or returns with empty alternatives keep the original node - Stream((this, false, false)) - } else { - // If the rule is applied then use the returned alternatives. The alternatives can include the - // current node and we need to keep track of that - var foundEqual = false - afterRules.toStream.map { afterRule => - (if (this fastEquals afterRule) { - foundEqual = true - this - } else { - afterRule.copyTagsFrom(this) - afterRule - }, true, false) - }.append( - // If the current node is not a leaf node and the alternatives returned by the rule doesn't - // contain it then we need to add the current node to the stream, but require any of its - // child nodes to be transformed to keep it as a valid alternative - if (containsChild.nonEmpty && !foundEqual) { - Stream((this, false, true)) - } else { - Stream.empty - } - ) + val afterRulesStream = afterRules match { + // If the rule returns with empty alternatives then prune + case Nil => Stream.empty + + // If the rule returns with a node equal to the original (or not applied) then keep the + // original node + case afterRule :: Nil if this fastEquals afterRule => Stream((this, false, false)) + + // If the rule is applied then use the returned alternatives + case _ => + // The alternatives can include the current node and we need to keep track of that + var foundEqual = false + afterRules.toStream.map { afterRule => + (if (this fastEquals afterRule) { + foundEqual = true + this + } else { + afterRule.copyTagsFrom(this) + afterRule + }, true, false) + }.append( + // If the current node is not a leaf node and the alternatives returned by the rule + // doesn't contain it then we need to add the current node to the stream, but require any + // of its child nodes to be transformed to keep it as a valid alternative + if (containsChild.nonEmpty && !foundEqual) { + Stream((this, false, true)) + } else { + Stream.empty + } + ) } def generateChildrenSeq(children: Seq[BaseType]): Stream[(Seq[BaseType], Boolean)] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index f2cec5064fca3..22661ce344cfa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -1088,4 +1088,17 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { } yield Add(ab, cd) assert(transformed == expected) } + + test("multiTransformDown can prune") { + val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) + val transformed = e.multiTransformDown { + case StringLiteral("a") => Seq.empty + } + assert(transformed.isEmpty) + + val transformed2 = e.multiTransformDown { + case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq.empty + } + assert(transformed2.isEmpty) + } } From 8de8f88222e98d2bfaa760cf30267ca8c7d59372 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 16 Jan 2023 14:53:48 +0100 Subject: [PATCH 3/5] change rule to require `Stream` of alternatives instead of `Seq`, add explicit flag to enable the `autoContinue` feature, add more examples, remove general versions to highlight this is a top-down rule --- .../spark/sql/catalyst/trees/TreeNode.scala | 152 +++++++++++------- .../sql/catalyst/trees/TreeNodeSuite.scala | 82 +++++----- 2 files changed, 139 insertions(+), 95 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 451c8d7a14345..12b6924d62fa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -622,13 +622,29 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre * Returns alternative copies of this node where `rule` has been recursively applied to the tree. * * Users should not expect a specific directionality. If a specific directionality is needed, - * multiTransformDown or multiTransformUp should be used. + * multiTransformDownWithPruning or multiTransformUpWithPruning should be used. * * @param rule a function used to generate transformed alternatives for a node * @return the stream of alternatives */ - def multiTransform(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = { - multiTransformDown(rule) + def multiTransformDown( + rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = { + multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Returns alternative copies of this node where `rule` has been recursively applied to the tree. + * + * Users should not expect a specific directionality. If a specific directionality is needed, + * multiTransformDownWithPruning or multiTransformUpWithPruning should be used. + * + * @param rule a function used to generate transformed alternatives for a node and the + * `autoContinue` flag + * @return the stream of alternatives + */ + def multiTransformDownWithContinuation( + rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[BaseType] = { + multiTransformDownWithContinuationAndPruning(AlwaysProcess.fn, UnknownRuleId)(rule) } /** @@ -648,22 +664,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre * varying initial state for different invocations. * @return the stream of alternatives */ - def multiTransformWithPruning( + def multiTransformDownWithPruning( cond: TreePatternBits => Boolean, ruleId: RuleId = UnknownRuleId - )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = { - multiTransformDownWithPruning(cond, ruleId)(rule).map(_._1) - } - - /** - * Returns alternative copies of this node where `rule` has been recursively applied to it and all - * of its children (pre-order). - * - * @param rule the function used to generate transformed alternatives for a node - * @return the stream of alternatives - */ - def multiTransformDown(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = { - multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule).map(_._1) + )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = { + multiTransformDownWithContinuationAndPruning(cond, ruleId)(rule.andThen(_ -> false)) } /** @@ -675,30 +680,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre * lazy `Stream` to be able to limit the number of alternatives generated at the caller side as * needed. * - * To indicate that the original node without any transformation is a valid alternative the rule - * can either: - * - not apply or - * - a `Seq` that contains a node that is equal to the original node. + * The rule should not apply to indicate that the original node without any transformation is a + * valid alternative. * - * The rule can return `Seq.empty` to indicate that the original node should be pruned from the - * alternatives. + * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this + * case `multiTransform` returns an empty `Stream`. + * + * Please consider the following examples of `input.multiTransform(rule)`: + * + * We have an input expression: + * `Add(a, b)` + * + * 1. + * We have a simple rule: + * `a` => `Stream(1, 2)` + * `b` => `Stream(10, 20)` + * `Add(a, b)` => `Stream(11, 12, 21, 22)` + * + * The output is: + * `Stream(11, 12, 21, 22)` + * + * 2. + * In the previous example if we want to generate alternatives of `a` and `b` too then we need to + * explicitly add the original `Add(a, b)` expression to the rule: + * `a` => `Stream(1, 2)` + * `b` => `Stream(10, 20)` + * `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))` + * + * The output is: + * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))` + * + * 3. + * It is not always easy to determine if we will do any child expression mapping but we can enable + * the `autoContinue` flag to get the same result: + * `a` => `(Stream(1, 2), false)` + * `b` => `(Stream(10, 20), false)` + * `Add(a, b)` => `(Stream(11, 12, 21, 22), true)` (Note the `true` flag and the missing + * `Add(a, b)`) + * The output is the same as in 2.: + * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))` * - * Please note that this function always consider the original node as a valid alternative (even - * if the original node is not included in the returned `Seq`) if the rule can transform any of - * the descendants of the node. E.g. consider a simple expression: - * `Add(a, b)` - * and a rule that returns: - * `Seq(1, 2)` for `a` and - * `Seq(10, 20)` for `b` and - * `Seq(11, 12, 21, 22)` for `Add(a, b)` (note that the original `Add(a, b)` is not returned) - * then the result of `multiTransform` is: - * `Seq(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`. * This feature makes the usage of `multiTransform` easier as a non-leaf transforming rule doesn't * need to take into account that it can transform a descendant node of the non-leaf node as well * and so it doesn't need return the non-leaf node itself in the list of alternatives to not stop * generating alternatives. * - * @param rule a function used to generate transformed alternatives for a node + * @param rule a function used to generate transformed alternatives for a node and the + * `autoContinue` flag * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false * on a TreeNode T, skips processing T and its subtree; otherwise, processes * T and its subtree recursively. @@ -707,37 +735,48 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre * has been marked as in effective on a TreeNode T, skips processing T and its * subtree. Do not pass it if the rule is not purely functional and reads a * varying initial state for different invocations. - * @return the stream of alternatives with a flag if any transformation was done + * @return the stream of alternatives */ - def multiTransformDownWithPruning( + def multiTransformDownWithContinuationAndPruning( cond: TreePatternBits => Boolean, ruleId: RuleId = UnknownRuleId - )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[(BaseType, Boolean)] = { + )(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[BaseType] = { + multiTransformDownHelper(cond, ruleId)(rule).map(_._1) + } + + private def multiTransformDownHelper( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId + )(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[(BaseType, Boolean)] = { if (!cond.apply(this) || isRuleIneffective(ruleId)) { return Stream(this -> false) } - val afterRules = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, (t: BaseType) => Seq(t)) + var ruleApplied = true + val (afterRules, autoContinue) = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, (_: BaseType) => { + ruleApplied = false + Stream.empty -> false + }) } // A stream of a tuple that contains: // - a node that is either the transformed alternative of the current node or the current node, // - a boolean flag if the node was actually transformed, // - a boolean flag if a node's children needs to be transformed to add the node to the valid // alternatives - val afterRulesStream = afterRules match { - // If the rule returns with empty alternatives then prune - case Nil => Stream.empty - - // If the rule returns with a node equal to the original (or not applied) then keep the - // original node - case afterRule :: Nil if this fastEquals afterRule => Stream((this, false, false)) - - // If the rule is applied then use the returned alternatives - case _ => + val afterRulesStream = if (afterRules.isEmpty) { + if (ruleApplied) { + // If the rule returned with empty alternatives then prune + Stream.empty + } else { + // If the rule was not applied then keep the original node + Stream((this, false, false)) + } + } else { + // If the rule was applied then use the returned alternatives // The alternatives can include the current node and we need to keep track of that var foundEqual = false - afterRules.toStream.map { afterRule => + afterRules.map { afterRule => (if (this fastEquals afterRule) { foundEqual = true this @@ -746,10 +785,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre afterRule }, true, false) }.append( - // If the current node is not a leaf node and the alternatives returned by the rule - // doesn't contain it then we need to add the current node to the stream, but require any - // of its child nodes to be transformed to keep it as a valid alternative - if (containsChild.nonEmpty && !foundEqual) { + // If autoContinue is enabled and the current node is not a leaf node and the alternatives + // returned by the rule doesn't contain the current node then we need to add the current + // node to the stream, but require any of its child nodes to be transformed to keep it as + // a valid alternative + if (autoContinue && containsChild.nonEmpty && !foundEqual) { Stream((this, false, true)) } else { Stream.empty @@ -761,7 +801,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre children.foldRight(Stream((Seq.empty[BaseType], false)))((child, childrenSeqStream) => for { (childrenSeq, childrenSeqChanged) <- childrenSeqStream - (newChild, childChanged) <- child.multiTransformDownWithPruning(cond, ruleId)(rule) + (newChild, childChanged) <- child.multiTransformDownHelper(cond, ruleId)(rule) } yield (newChild +: childrenSeq) -> (childChanged || childrenSeqChanged) ) } @@ -774,7 +814,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre afterRule.withNewChildren(newChildren) -> (transformed || childrenTransformed) } } else { - Seq(afterRule -> transformed) + Stream(afterRule -> transformed) }.map { rewritten_plan => if (this eq rewritten_plan) { markRuleAsIneffective(ruleId) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 22661ce344cfa..f49283b90089e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -987,10 +987,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { test("multiTransformDown generates all alternatives") { val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) val transformed = e.multiTransformDown { - case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3)) - case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30)) + case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30)) case Add(StringLiteral("c"), StringLiteral("d"), _) => - Seq(Literal(100), Literal(200), Literal(300)) + Stream(Literal(100), Literal(200), Literal(300)) } val expected = for { cd <- Seq(Literal(100), Literal(200), Literal(300)) @@ -1002,10 +1002,11 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { test("multiTransformDown is lazy") { val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) - val transformed = e.multiTransformDown { - case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3)) - case StringLiteral("b") => newErrorAfterStream(Literal(10)) - case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100)) + val transformed = e.multiTransformDownWithContinuation { + case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) -> true + case StringLiteral("b") => newErrorAfterStream(Literal(10)) -> true + case Add(StringLiteral("c"), StringLiteral("d"), _) => + newErrorAfterStream(Literal(100)) -> true } val expected = for { a <- Seq(Literal(1), Literal(2), Literal(3)) @@ -1016,28 +1017,30 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { transformed.take(3 + 1).toList } - val transformed2 = e.multiTransformDown { - case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3)) - case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30)) - case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100)) - } - val expected2 = for { - b <- Seq(Literal(10), Literal(20), Literal(30)) - a <- Seq(Literal(1), Literal(2), Literal(3)) - } yield Add(Add(a, b), Literal(100)) - // We don't access alternatives for `c` after 100 - assert(transformed2.take(3 * 3) === expected2) - intercept[NoSuchElementException] { - transformed.take(3 * 3 + 1).toList - } +// val transformed2 = e.multiTransformDownWithContinuation { +// case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) -> true +// case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30)) -> true +// case Add(StringLiteral("c"), StringLiteral("d"), _) => +// newErrorAfterStream(Literal(100)) -> true +// } +// val expected2 = for { +// b <- Seq(Literal(10), Literal(20), Literal(30)) +// a <- Seq(Literal(1), Literal(2), Literal(3)) +// } yield Add(Add(a, b), Literal(100)) +// // We don't access alternatives for `c` after 100 +// assert(transformed2.take(3 * 3) === expected2) +// intercept[NoSuchElementException] { +// transformed.take(3 * 3 + 1).toList +// } } test("multiTransformDown rule return this") { val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) val transformed = e.multiTransformDown { - case s @ StringLiteral("a") => Seq(Literal(1), Literal(2), s) - case s @ StringLiteral("b") => Seq(Literal(10), Literal(20), s) - case a @ Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200), a) + case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s) + case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s) + case a @ Add(StringLiteral("c"), StringLiteral("d"), _) => + Stream(Literal(100), Literal(200), a) } val expected = for { cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d"))) @@ -1047,15 +1050,16 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(transformed == expected) } - test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " + - "transformed") { + test("multiTransformDownWithContinuation doesn't stop generating alternatives of descendants " + + "when non-leaf is transformed but the itself is not in the alternatives") { val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) - val transformed = e.multiTransformDown { + val transformed = e.multiTransformDownWithContinuation { case Add(StringLiteral("a"), StringLiteral("b"), _) => - Seq(Literal(11), Literal(12), Literal(21), Literal(22)) - case StringLiteral("a") => Seq(Literal(1), Literal(2)) - case StringLiteral("b") => Seq(Literal(10), Literal(20)) - case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200)) + Stream(Literal(11), Literal(12), Literal(21), Literal(22)) -> true + case StringLiteral("a") => Stream(Literal(1), Literal(2)) -> true + case StringLiteral("b") => Stream(Literal(10), Literal(20)) -> true + case Add(StringLiteral("c"), StringLiteral("d"), _) => + Stream(Literal(100), Literal(200)) -> true } val expected = for { cd <- Seq(Literal(100), Literal(200)) @@ -1068,15 +1072,15 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(transformed == expected) } - test("multiTransformDown non-leaf transformation if a descendant can be transformed too " + - "behaves like non-leaf returned itself") { + test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " + + "transformed and itself is in the alternatives") { val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) val transformed = e.multiTransformDown { case a @ Add(StringLiteral("a"), StringLiteral("b"), _) => - Seq(Literal(11), Literal(12), Literal(21), Literal(22), a) - case StringLiteral("a") => Seq(Literal(1), Literal(2)) - case StringLiteral("b") => Seq(Literal(10), Literal(20)) - case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200)) + Stream(Literal(11), Literal(12), Literal(21), Literal(22), a) + case StringLiteral("a") => Stream(Literal(1), Literal(2)) + case StringLiteral("b") => Stream(Literal(10), Literal(20)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200)) } val expected = for { cd <- Seq(Literal(100), Literal(200)) @@ -1092,12 +1096,12 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { test("multiTransformDown can prune") { val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) val transformed = e.multiTransformDown { - case StringLiteral("a") => Seq.empty + case StringLiteral("a") => Stream.empty } assert(transformed.isEmpty) val transformed2 = e.multiTransformDown { - case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq.empty + case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty } assert(transformed2.isEmpty) } From afcef4e95fbfec085b0d81dbf9929ab246156f69 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 17 Jan 2023 10:02:15 +0100 Subject: [PATCH 4/5] fix documentation, remove autoContinue feature, revisit when we can mark a rule ineffective, move generateChildrenSeq as a top level helper as it will help with early pruning --- .../spark/sql/catalyst/trees/TreeNode.scala | 178 +++++------------- .../sql/catalyst/trees/TreeNodeSuite.scala | 60 ++---- 2 files changed, 68 insertions(+), 170 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 12b6924d62fa3..51b2d56831f6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -619,12 +619,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre } /** - * Returns alternative copies of this node where `rule` has been recursively applied to the tree. - * - * Users should not expect a specific directionality. If a specific directionality is needed, - * multiTransformDownWithPruning or multiTransformUpWithPruning should be used. + * Returns alternative copies of this node where `rule` has been recursively applied to it and all + * of its children (pre-order). * - * @param rule a function used to generate transformed alternatives for a node + * @param rule a function used to generate alternatives for a node * @return the stream of alternatives */ def multiTransformDown( @@ -632,59 +630,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) } - /** - * Returns alternative copies of this node where `rule` has been recursively applied to the tree. - * - * Users should not expect a specific directionality. If a specific directionality is needed, - * multiTransformDownWithPruning or multiTransformUpWithPruning should be used. - * - * @param rule a function used to generate transformed alternatives for a node and the - * `autoContinue` flag - * @return the stream of alternatives - */ - def multiTransformDownWithContinuation( - rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[BaseType] = { - multiTransformDownWithContinuationAndPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Returns alternative copies of this node where `rule` has been recursively applied to the tree. - * - * Users should not expect a specific directionality. If a specific directionality is needed, - * multiTransformDownWithPruning or multiTransformUpWithPruning should be used. - * - * @param rule a function used to generate transformed alternatives for a node - * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false - * on a TreeNode T, skips processing T and its subtree; otherwise, processes - * T and its subtree recursively. - * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is - * UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`) - * has been marked as in effective on a TreeNode T, skips processing T and its - * subtree. Do not pass it if the rule is not purely functional and reads a - * varying initial state for different invocations. - * @return the stream of alternatives - */ - def multiTransformDownWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId - )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = { - multiTransformDownWithContinuationAndPruning(cond, ruleId)(rule.andThen(_ -> false)) - } - /** * Returns alternative copies of this node where `rule` has been recursively applied to it and all * of its children (pre-order). * * As it is very easy to generate enormous number of alternatives when the input tree is huge or - * when the rule returns large number of alternatives, this function returns the alternatives as a - * lazy `Stream` to be able to limit the number of alternatives generated at the caller side as - * needed. + * when the rule returns many alternatives for many nodes, this function returns the alternatives + * as a lazy `Stream` to be able to limit the number of alternatives generated at the caller side + * as needed. * - * The rule should not apply to indicate that the original node without any transformation is a - * valid alternative. + * The rule should not apply or can return a one element stream of original node to indicate that + * the original node without any transformation is a valid alternative. * * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this - * case `multiTransform` returns an empty `Stream`. + * case `multiTransform()` returns an empty `Stream`. * * Please consider the following examples of `input.multiTransform(rule)`: * @@ -710,23 +669,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre * The output is: * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))` * - * 3. - * It is not always easy to determine if we will do any child expression mapping but we can enable - * the `autoContinue` flag to get the same result: - * `a` => `(Stream(1, 2), false)` - * `b` => `(Stream(10, 20), false)` - * `Add(a, b)` => `(Stream(11, 12, 21, 22), true)` (Note the `true` flag and the missing - * `Add(a, b)`) - * The output is the same as in 2.: - * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))` - * - * This feature makes the usage of `multiTransform` easier as a non-leaf transforming rule doesn't - * need to take into account that it can transform a descendant node of the non-leaf node as well - * and so it doesn't need return the non-leaf node itself in the list of alternatives to not stop - * generating alternatives. - * - * @param rule a function used to generate transformed alternatives for a node and the - * `autoContinue` flag + * @param rule a function used to generate alternatives for a node * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false * on a TreeNode T, skips processing T and its subtree; otherwise, processes * T and its subtree recursively. @@ -737,93 +680,72 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre * varying initial state for different invocations. * @return the stream of alternatives */ - def multiTransformDownWithContinuationAndPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId - )(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[BaseType] = { - multiTransformDownHelper(cond, ruleId)(rule).map(_._1) - } - - private def multiTransformDownHelper( + def multiTransformDownWithPruning( cond: TreePatternBits => Boolean, ruleId: RuleId = UnknownRuleId - )(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[(BaseType, Boolean)] = { + )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = { if (!cond.apply(this) || isRuleIneffective(ruleId)) { - return Stream(this -> false) + return Stream(this) } + // We could return `Stream(this)` if the `rule` doesn't apply and handle both + // - the doesn't apply + // - and the rule returns a one element `Stream(originalNode)` + // cases together. But, unfortunately it doesn't seem like there is a way to match on a one + // element stream without eagerly computing the tail head. So this contradicts with the purpose + // of only taking the necessary elements from the alternatives. I.e. the + // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail. + // Please note that this behaviour has a downside as well that we can only mark the rule on the + // original node ineffective if the rule didn't match. var ruleApplied = true - val (afterRules, autoContinue) = CurrentOrigin.withOrigin(origin) { + val afterRules = CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, (_: BaseType) => { ruleApplied = false - Stream.empty -> false + Stream.empty }) } - // A stream of a tuple that contains: - // - a node that is either the transformed alternative of the current node or the current node, - // - a boolean flag if the node was actually transformed, - // - a boolean flag if a node's children needs to be transformed to add the node to the valid - // alternatives + val afterRulesStream = if (afterRules.isEmpty) { if (ruleApplied) { // If the rule returned with empty alternatives then prune Stream.empty } else { // If the rule was not applied then keep the original node - Stream((this, false, false)) + this.markRuleAsIneffective(ruleId) + Stream(this) } } else { - // If the rule was applied then use the returned alternatives - // The alternatives can include the current node and we need to keep track of that - var foundEqual = false - afterRules.map { afterRule => - (if (this fastEquals afterRule) { - foundEqual = true - this - } else { - afterRule.copyTagsFrom(this) - afterRule - }, true, false) - }.append( - // If autoContinue is enabled and the current node is not a leaf node and the alternatives - // returned by the rule doesn't contain the current node then we need to add the current - // node to the stream, but require any of its child nodes to be transformed to keep it as - // a valid alternative - if (autoContinue && containsChild.nonEmpty && !foundEqual) { - Stream((this, false, true)) - } else { - Stream.empty - } - ) - } - - def generateChildrenSeq(children: Seq[BaseType]): Stream[(Seq[BaseType], Boolean)] = { - children.foldRight(Stream((Seq.empty[BaseType], false)))((child, childrenSeqStream) => - for { - (childrenSeq, childrenSeqChanged) <- childrenSeqStream - (newChild, childChanged) <- child.multiTransformDownHelper(cond, ruleId)(rule) - } yield (newChild +: childrenSeq) -> (childChanged || childrenSeqChanged) - ) + // If the rule was applied then use the returned alternatives + afterRules.map { afterRule => + if (this fastEquals afterRule) { + this + } else { + afterRule.copyTagsFrom(this) + afterRule + } + } } - afterRulesStream.flatMap { case (afterRule, transformed, childrenTransformRequired) => + afterRulesStream.flatMap { afterRule => if (afterRule.containsChild.nonEmpty) { - generateChildrenSeq(afterRule.children).collect { - case (newChildren, childrenTransformed) - if !childrenTransformRequired || childrenTransformed => - afterRule.withNewChildren(newChildren) -> (transformed || childrenTransformed) - } + generateChildrenSeq( + afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule))) + .map(afterRule.withNewChildren) } else { - Stream(afterRule -> transformed) - }.map { rewritten_plan => - if (this eq rewritten_plan) { - markRuleAsIneffective(ruleId) - } - rewritten_plan + Stream(afterRule) } } } + private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = { + childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, childrenSeqStream) => + for { + childrenSeq <- childrenSeqStream + child <- childrenStream + } yield child +: childrenSeq + ) + } + /** * Returns a copy of this node where `f` has been applied to all the nodes in `children`. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index f49283b90089e..ac28917675e6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -1002,11 +1002,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { test("multiTransformDown is lazy") { val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) - val transformed = e.multiTransformDownWithContinuation { - case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) -> true - case StringLiteral("b") => newErrorAfterStream(Literal(10)) -> true - case Add(StringLiteral("c"), StringLiteral("d"), _) => - newErrorAfterStream(Literal(100)) -> true + val transformed = e.multiTransformDown { + case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => newErrorAfterStream(Literal(10)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100)) } val expected = for { a <- Seq(Literal(1), Literal(2), Literal(3)) @@ -1017,21 +1016,20 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { transformed.take(3 + 1).toList } -// val transformed2 = e.multiTransformDownWithContinuation { -// case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) -> true -// case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30)) -> true -// case Add(StringLiteral("c"), StringLiteral("d"), _) => -// newErrorAfterStream(Literal(100)) -> true -// } -// val expected2 = for { -// b <- Seq(Literal(10), Literal(20), Literal(30)) -// a <- Seq(Literal(1), Literal(2), Literal(3)) -// } yield Add(Add(a, b), Literal(100)) -// // We don't access alternatives for `c` after 100 -// assert(transformed2.take(3 * 3) === expected2) -// intercept[NoSuchElementException] { -// transformed.take(3 * 3 + 1).toList -// } + val transformed2 = e.multiTransformDown { + case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) + case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30)) + case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100)) + } + val expected2 = for { + b <- Seq(Literal(10), Literal(20), Literal(30)) + a <- Seq(Literal(1), Literal(2), Literal(3)) + } yield Add(Add(a, b), Literal(100)) + // We don't access alternatives for `c` after 100 + assert(transformed2.take(3 * 3) === expected2) + intercept[NoSuchElementException] { + transformed.take(3 * 3 + 1).toList + } } test("multiTransformDown rule return this") { @@ -1050,28 +1048,6 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(transformed == expected) } - test("multiTransformDownWithContinuation doesn't stop generating alternatives of descendants " + - "when non-leaf is transformed but the itself is not in the alternatives") { - val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) - val transformed = e.multiTransformDownWithContinuation { - case Add(StringLiteral("a"), StringLiteral("b"), _) => - Stream(Literal(11), Literal(12), Literal(21), Literal(22)) -> true - case StringLiteral("a") => Stream(Literal(1), Literal(2)) -> true - case StringLiteral("b") => Stream(Literal(10), Literal(20)) -> true - case Add(StringLiteral("c"), StringLiteral("d"), _) => - Stream(Literal(100), Literal(200)) -> true - } - val expected = for { - cd <- Seq(Literal(100), Literal(200)) - ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++ - (for { - b <- Seq(Literal(10), Literal(20)) - a <- Seq(Literal(1), Literal(2)) - } yield Add(a, b)) - } yield Add(ab, cd) - assert(transformed == expected) - } - test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " + "transformed and itself is in the alternatives") { val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d"))) From 60323c8de656643987ab648b42a341cbdc9c3ed7 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 17 Jan 2023 10:46:04 +0100 Subject: [PATCH 5/5] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala Co-authored-by: Wenchen Fan --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 51b2d56831f6f..dc64e5e256052 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -645,7 +645,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this * case `multiTransform()` returns an empty `Stream`. * - * Please consider the following examples of `input.multiTransform(rule)`: + * Please consider the following examples of `input.multiTransformDown(rule)`: * * We have an input expression: * `Add(a, b)`