-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-40599][SQL] Add multiTransform methods to TreeNode to generate alternatives #38034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
52546d0
b99f2f9
8de8f88
afcef4e
60323c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -618,6 +618,134 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Returns alternative copies of this node where `rule` has been recursively applied to it and all | ||
| * of its children (pre-order). | ||
| * | ||
| * @param rule a function used to generate alternatives for a node | ||
| * @return the stream of alternatives | ||
| */ | ||
| def multiTransformDown( | ||
| rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = { | ||
| // Delegate to the pruning variant. `UnknownRuleId` is that method's default rule id; | ||
| // `AlwaysProcess.fn` is presumably a condition that accepts every node, i.e. no subtree is | ||
| // skipped (its definition is not visible here -- TODO confirm). | ||
| multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) | ||
| } | ||
|
|
||
| /** | ||
| * Returns alternative copies of this node where `rule` has been recursively applied to it and all | ||
| * of its children (pre-order). | ||
| * | ||
| * As it is very easy to generate an enormous number of alternatives when the input tree is huge or | ||
| * when the rule returns many alternatives for many nodes, this function returns the alternatives | ||
| * as a lazy `Stream` to be able to limit the number of alternatives generated at the caller side | ||
| * as needed. | ||
| * | ||
| * The rule should either not apply or return a one element stream of the original node to | ||
| * indicate that the original node without any transformation is a valid alternative. | ||
| * | ||
| * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this | ||
| * case `multiTransformDown()` returns an empty `Stream`. | ||
| * | ||
| * Please consider the following examples of `input.multiTransformDown(rule)`: | ||
| * | ||
| * We have an input expression: | ||
| * `Add(a, b)` | ||
| * | ||
| * 1. | ||
| * We have a simple rule: | ||
| * `a` => `Stream(1, 2)` | ||
| * `b` => `Stream(10, 20)` | ||
| * `Add(a, b)` => `Stream(11, 12, 21, 22)` | ||
| * | ||
| * The output is: | ||
| * `Stream(11, 12, 21, 22)` | ||
| * | ||
| * 2. | ||
| * In the previous example if we want to generate alternatives of `a` and `b` too then we need to | ||
| * explicitly add the original `Add(a, b)` expression to the rule: | ||
| * `a` => `Stream(1, 2)` | ||
| * `b` => `Stream(10, 20)` | ||
| * `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))` | ||
| * | ||
| * The output is: | ||
| * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))` | ||
| * | ||
| * @param rule a function used to generate alternatives for a node | ||
| * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false | ||
| * on a TreeNode T, skips processing T and its subtree; otherwise, processes | ||
| * T and its subtree recursively. | ||
| * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is | ||
| * UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`) | ||
| * has been marked as ineffective on a TreeNode T, skips processing T and its | ||
| * subtree. Do not pass it if the rule is not purely functional and reads a | ||
| * varying initial state for different invocations. | ||
| * @return the stream of alternatives | ||
| */ | ||
| def multiTransformDownWithPruning( | ||
| cond: TreePatternBits => Boolean, | ||
| ruleId: RuleId = UnknownRuleId | ||
| )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = { | ||
| // Pruned traversal: if `cond` rejects this node or the rule is already marked ineffective | ||
| // here, the unchanged node is the only alternative and its subtree is not visited. | ||
| if (!cond.apply(this) || isRuleIneffective(ruleId)) { | ||
| return Stream(this) | ||
| } | ||
|
|
||
| // We could return `Stream(this)` if the `rule` doesn't apply and handle both | ||
| // - the rule doesn't apply | ||
| // - and the rule returns a one element `Stream(originalNode)` | ||
| // cases together. But, unfortunately, it doesn't seem like there is a way to match on a one | ||
| // element stream without eagerly computing the head of its tail. That would contradict the | ||
| // purpose of only taking the necessary elements from the alternatives, i.e. the | ||
| // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail. | ||
| // Please note that this behaviour has a downside as well: we can only mark the rule as | ||
| // ineffective on the original node if the rule didn't match. | ||
| // | ||
| // `ruleApplied` is flipped to false by the fallback passed to `applyOrElse`, which only runs | ||
| // when `rule` is not defined for this node. It distinguishes "rule did not match" from | ||
| // "rule matched and returned no alternatives", since both produce an empty stream here. | ||
| var ruleApplied = true | ||
| val afterRules = CurrentOrigin.withOrigin(origin) { | ||
| rule.applyOrElse(this, (_: BaseType) => { | ||
| ruleApplied = false | ||
| Stream.empty | ||
| }) | ||
| } | ||
|
|
||
| val afterRulesStream = if (afterRules.isEmpty) { | ||
| if (ruleApplied) { | ||
| // If the rule returned with empty alternatives then prune | ||
| Stream.empty | ||
| } else { | ||
| // If the rule was not applied then keep the original node | ||
| this.markRuleAsIneffective(ruleId) | ||
| Stream(this) | ||
| } | ||
| } else { | ||
| // If the rule was applied then use the returned alternatives | ||
| // An alternative that `fastEquals` this node is replaced by this node itself; any other | ||
| // alternative gets this node's tags copied onto it. | ||
| afterRules.map { afterRule => | ||
| if (this fastEquals afterRule) { | ||
| this | ||
| } else { | ||
| afterRule.copyTagsFrom(this) | ||
| afterRule | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Pre-order step: for every alternative of this node, recursively generate the alternatives | ||
| // of each child, combine them with `generateChildrenSeq` and rebuild the alternative with | ||
| // each combination. Alternatives for which `containsChild` is empty are emitted as they are. | ||
| afterRulesStream.flatMap { afterRule => | ||
| if (afterRule.containsChild.nonEmpty) { | ||
| generateChildrenSeq( | ||
| afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule))) | ||
| .map(afterRule.withNewChildren) | ||
| } else { | ||
| Stream(afterRule) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Builds, as a `Stream`, every combination that takes one element from each stream in | ||
| * `childrenStreams`; each combination keeps the elements in the order of the input streams. | ||
| * | ||
| * If `childrenStreams` is empty the result is a single empty `Seq`; if any of the input | ||
| * streams is empty the result is empty. | ||
| * | ||
| * @param childrenStreams the alternatives of each child, one stream per child | ||
| * @return the stream of children sequences | ||
| */ | ||
| private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = { | ||
| childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, childrenSeqStream) => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we fold from right to left?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to generate alternatives for the first children of an expression first. |
||
| // The already combined sequences of the later children form the outer loop and the current | ||
| // (earlier) child forms the inner loop, so the alternatives of the first child vary fastest | ||
| // in the output, e.g. Stream(1, 2) and Stream(10, 20) give (1, 10), (2, 10), (1, 20), (2, 20). | ||
| for { | ||
| childrenSeq <- childrenSeqStream | ||
| child <- childrenStream | ||
| } yield child +: childrenSeq | ||
| ) | ||
| } | ||
|
|
||
| /** | ||
| * Returns a copy of this node where `f` has been applied to all the nodes in `children`. | ||
| */ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.