Closed
Changes from 3 commits
@@ -618,6 +618,212 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
}
}

/**
* Returns alternative copies of this node where `rule` has been recursively applied to the tree.
*
* Users should not expect a specific directionality. If a specific directionality is needed,
Contributor
The comment needs an update. We should also explain why only the down direction is provided.

Contributor Author
I've modified the Scaladoc; let me know if we need more details or fixes.

* multiTransformDownWithPruning or multiTransformUpWithPruning should be used.
*
* @param rule a function used to generate transformed alternatives for a node
Contributor
Suggested change:
- * @param rule a function used to generate transformed alternatives for a node
+ * @param rule a function used to generate alternatives for a node

Contributor Author
Done.

* @return the stream of alternatives
*/
def multiTransformDown(
rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
}

/**
* Returns alternative copies of this node where `rule` has been recursively applied to the tree.
*
* Users should not expect a specific directionality. If a specific directionality is needed,
* multiTransformDownWithPruning or multiTransformUpWithPruning should be used.
Contributor
ditto

Contributor Author
Done.

*
* @param rule a function used to generate transformed alternatives for a node and the
* `autoContinue` flag
* @return the stream of alternatives
*/
def multiTransformDownWithContinuation(
rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[BaseType] = {
multiTransformDownWithContinuationAndPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
}

/**
* Returns alternative copies of this node where `rule` has been recursively applied to the tree.
*
* Users should not expect a specific directionality. If a specific directionality is needed,
* multiTransformDownWithPruning or multiTransformUpWithPruning should be used.
*
* @param rule a function used to generate transformed alternatives for a node
* @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
* on a TreeNode T, skips processing T and its subtree; otherwise, processes
* T and its subtree recursively.
* @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
* UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
* has been marked as ineffective on a TreeNode T, skips processing T and its
* subtree. Do not pass it if the rule is not purely functional and reads a
* varying initial state for different invocations.
* @return the stream of alternatives
*/
def multiTransformDownWithPruning(
cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId
)(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
multiTransformDownWithContinuationAndPruning(cond, ruleId)(rule.andThen(_ -> false))
}

/**
* Returns alternative copies of this node where `rule` has been recursively applied to it and all
* of its children (pre-order).
*
* As it is very easy to generate an enormous number of alternatives when the input tree is huge or
* when the rule returns large number of alternatives, this function returns the alternatives as a
Contributor
Suggested change:
- * when the rule returns large number of alternatives, this function returns the alternatives as a
+ * when the rule returns many alternatives for many nodes, this function returns the alternatives as a

Contributor Author
Done.

* lazy `Stream` to be able to limit the number of alternatives generated at the caller side as
* needed.
*
* To indicate that the original node, without any transformation, is a valid alternative, the
* rule should simply not apply to that node.
*
* The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
* case `multiTransform` returns an empty `Stream`.
Contributor
Let's also mention that if the rule applies but you want the not-apply behavior, you can just return `Seq(originalNode)`.

Contributor Author
Done.
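
Editorial sketch of the three rule outcomes discussed in this thread, reduced to a single node. `Leaf` and `alternativesAt` are hypothetical names, not Spark APIs, and the rule returns a `Stream` (as in the signatures in this diff) rather than the `Seq` mentioned in the comment. The PR itself tracks a `ruleApplied` flag; this simplified version folds the not-applied case into the `applyOrElse` fallback.

```scala
object RuleOutcomes {
  // Hypothetical stand-in for a tree node; not Spark's TreeNode.
  final case class Leaf(name: String)

  // Alternatives for one node. The fallback of applyOrElse runs only
  // when the rule is not defined at `node`.
  def alternativesAt(node: Leaf)(
      rule: PartialFunction[Leaf, Stream[Leaf]]): Stream[Leaf] =
    rule.applyOrElse(node, (n: Leaf) => Stream(n))

  def main(args: Array[String]): Unit = {
    val a = Leaf("a")
    // 1. The rule does not apply: the original node is the only alternative.
    println(alternativesAt(a) { case Leaf("x") => Stream(Leaf("1")) }.toList)
    // 2. The rule applies and returns nothing: the node is pruned.
    println(alternativesAt(a) { case Leaf("a") => Stream.empty }.toList)
    // 3. The rule applies but returns the original node: same result as 1.
    println(alternativesAt(a) { case n @ Leaf("a") => Stream(n) }.toList)
  }
}
```

Cases 1 and 3 yield the same alternatives here; in the PR they differ only in whether the rule counts as having been applied.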

*
* Please consider the following examples of `input.multiTransform(rule)`:
*
* We have an input expression:
* `Add(a, b)`
*
* 1.
* We have a simple rule:
* `a` => `Stream(1, 2)`
* `b` => `Stream(10, 20)`
* `Add(a, b)` => `Stream(11, 12, 21, 22)`
*
* The output is:
* `Stream(11, 12, 21, 22)`
*
* 2.
* In the previous example, if we want to generate alternatives of `a` and `b` too then we need to
* explicitly add the original `Add(a, b)` expression to the rule:
* `a` => `Stream(1, 2)`
* `b` => `Stream(10, 20)`
* `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
*
* The output is:
* `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
*
* 3.
* It is not always easy to determine if we will do any child expression mapping but we can enable
* the `autoContinue` flag to get the same result:
Contributor
I'm a bit worried about making multiTransform too complicated. Given the only use case for now is projecting expressions such as output partitions/orderings, can we simplify the rule a little bit? My preference is to remove this autoContinue flag and fully rely on the callers.

> It is not always easy to determine if we will do any child expression mapping

If we have a real SQL use case, I'm happy to change my mind.

Contributor Author
Yes, indeed. And I don't have any other use case either, so I will remove that.

Contributor Author
Ok, removed.

* `a` => `(Stream(1, 2), false)`
* `b` => `(Stream(10, 20), false)`
* `Add(a, b)` => `(Stream(11, 12, 21, 22), true)` (Note the `true` flag and the missing
* `Add(a, b)`)
* The output is the same as in 2.:
* `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
*
* This feature makes the usage of `multiTransform` easier as a non-leaf transforming rule doesn't
* need to take into account that it can transform a descendant node of the non-leaf node as well
* and so it doesn't need to return the non-leaf node itself in the list of alternatives to not stop
* generating alternatives.
*
* @param rule a function used to generate transformed alternatives for a node and the
* `autoContinue` flag
* @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
* on a TreeNode T, skips processing T and its subtree; otherwise, processes
* T and its subtree recursively.
* @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
* UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
* has been marked as ineffective on a TreeNode T, skips processing T and its
* subtree. Do not pass it if the rule is not purely functional and reads a
* varying initial state for different invocations.
* @return the stream of alternatives
*/
def multiTransformDownWithContinuationAndPruning(
cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId
)(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[BaseType] = {
multiTransformDownHelper(cond, ruleId)(rule).map(_._1)
}

private def multiTransformDownHelper(
cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId
)(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[(BaseType, Boolean)] = {
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
return Stream(this -> false)
}

var ruleApplied = true
val (afterRules, autoContinue) = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, (_: BaseType) => {
ruleApplied = false
Stream.empty -> false
})
}
// A stream of a tuple that contains:
// - a node that is either the transformed alternative of the current node or the current node,
// - a boolean flag if the node was actually transformed,
// - a boolean flag if a node's children need to be transformed to add the node to the valid
// alternatives
val afterRulesStream = if (afterRules.isEmpty) {
if (ruleApplied) {
// If the rule returned with empty alternatives then prune
Stream.empty
} else {
// If the rule was not applied then keep the original node
Stream((this, false, false))
}
} else {
// If the rule was applied then use the returned alternatives
// The alternatives can include the current node and we need to keep track of that
var foundEqual = false
afterRules.map { afterRule =>
(if (this fastEquals afterRule) {
foundEqual = true
this
} else {
afterRule.copyTagsFrom(this)
afterRule
}, true, false)
}.append(
// If autoContinue is enabled and the current node is not a leaf node and the alternatives
// returned by the rule don't contain the current node then we need to add the current
// node to the stream, but require any of its child nodes to be transformed to keep it as
// a valid alternative
if (autoContinue && containsChild.nonEmpty && !foundEqual) {
Stream((this, false, true))
} else {
Stream.empty
}
)
}

def generateChildrenSeq(children: Seq[BaseType]): Stream[(Seq[BaseType], Boolean)] = {
children.foldRight(Stream((Seq.empty[BaseType], false)))((child, childrenSeqStream) =>
for {
(childrenSeq, childrenSeqChanged) <- childrenSeqStream
(newChild, childChanged) <- child.multiTransformDownHelper(cond, ruleId)(rule)
} yield (newChild +: childrenSeq) -> (childChanged || childrenSeqChanged)
)
}

afterRulesStream.flatMap { case (afterRule, transformed, childrenTransformRequired) =>
if (afterRule.containsChild.nonEmpty) {
generateChildrenSeq(afterRule.children).collect {
case (newChildren, childrenTransformed)
if !childrenTransformRequired || childrenTransformed =>
afterRule.withNewChildren(newChildren) -> (transformed || childrenTransformed)
}
} else {
Stream(afterRule -> transformed)
}.map { rewritten_plan =>
if (this eq rewritten_plan) {
markRuleAsIneffective(ruleId)
Contributor
If my understanding is correct, the only way to mark this rule as ineffective is for it to return a stream with a single alternative that is `eq` to `this`. Then why do we use `map` here?

Contributor Author
Yeah, thanks, this looks wrong. I will fix it with other requests today.

Contributor Author
The latest commit fixes this, although there is a related issue: I'm not sure we can mark the rule ineffective if a one-element stream containing only the original node is returned:
https://github.com/apache/spark/pull/38034/files#diff-94575875fbf007fdaf43e4946c69c18649294fed974a46816ab1986f6350541bR691-R699

Contributor @cloud-fan, Jan 17, 2023
I think it's OK. We can mark the rule as ineffective only if the partial function does not apply. Once it applies, no matter what it returns, it's effective.

Contributor
seems fine
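
Editorial sketch of the policy agreed in this thread: a rule is recorded as ineffective only when the partial function does not apply, never based on what it returns. `Node`, `alternatives` and the integer rule ids are hypothetical, and the bookkeeping covers a single node only, whereas the real pruning also concerns the subtree.

```scala
import scala.collection.mutable

object IneffectiveRuleSketch {
  // Hypothetical node that remembers which rule ids did not apply to it.
  final class Node(val value: Int) {
    private val ineffectiveRules = mutable.BitSet.empty

    def isRuleIneffective(ruleId: Int): Boolean =
      ineffectiveRules.contains(ruleId)

    def alternatives(ruleId: Int)(
        rule: PartialFunction[Node, Stream[Node]]): Stream[Node] =
      if (isRuleIneffective(ruleId)) {
        // Known not to apply here: skip evaluating the rule.
        Stream(this)
      } else {
        var ruleApplied = true
        val result = rule.applyOrElse(this, (_: Node) => {
          ruleApplied = false
          Stream.empty[Node]
        })
        if (ruleApplied) {
          // Effective, whatever was returned (even Stream(this)).
          result
        } else {
          ineffectiveRules += ruleId
          Stream(this)
        }
      }
  }

  def main(args: Array[String]): Unit = {
    val n = new Node(1)
    n.alternatives(7) { case x if x.value > 10 => Stream(new Node(0)) }
    println(n.isRuleIneffective(7)) // the rule did not apply, so it is marked

    val m = new Node(1)
    m.alternatives(7) { case x => Stream(x) }
    println(m.isRuleIneffective(7)) // the rule applied, so it is not marked
  }
}
```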

}
rewritten_plan
}
}
}

/**
* Returns a copy of this node where `f` has been applied to all the nodes in `children`.
*/
@@ -977,4 +977,132 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
assert(origin.context.summary.isEmpty)
}
}

private def newErrorAfterStream(es: Expression*) = {
es.toStream.append(
throw new NoSuchElementException("Stream should not return more elements")
)
}

test("multiTransformDown generates all alternatives") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
Stream(Literal(100), Literal(200), Literal(300))
}
val expected = for {
cd <- Seq(Literal(100), Literal(200), Literal(300))
b <- Seq(Literal(10), Literal(20), Literal(30))
a <- Seq(Literal(1), Literal(2), Literal(3))
} yield Add(Add(a, b), cd)
assert(transformed === expected)
}

test("multiTransformDown is lazy") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDownWithContinuation {
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) -> true
case StringLiteral("b") => newErrorAfterStream(Literal(10)) -> true
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
newErrorAfterStream(Literal(100)) -> true
}
val expected = for {
a <- Seq(Literal(1), Literal(2), Literal(3))
} yield Add(Add(a, Literal(10)), Literal(100))
// We don't access alternatives for `b` after 10 and for `c` after 100
assert(transformed.take(3) == expected)
intercept[NoSuchElementException] {
transformed.take(3 + 1).toList
}

// val transformed2 = e.multiTransformDownWithContinuation {
// case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) -> true
// case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30)) -> true
// case Add(StringLiteral("c"), StringLiteral("d"), _) =>
// newErrorAfterStream(Literal(100)) -> true
// }
// val expected2 = for {
// b <- Seq(Literal(10), Literal(20), Literal(30))
// a <- Seq(Literal(1), Literal(2), Literal(3))
// } yield Add(Add(a, b), Literal(100))
// // We don't access alternatives for `c` after 100
// assert(transformed2.take(3 * 3) === expected2)
// intercept[NoSuchElementException] {
// transformed2.take(3 * 3 + 1).toList
// }
}

test("multiTransformDown rule return this") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
Stream(Literal(100), Literal(200), a)
}
val expected = for {
cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
b <- Seq(Literal(10), Literal(20), Literal("b"))
a <- Seq(Literal(1), Literal(2), Literal("a"))
} yield Add(Add(a, b), cd)
assert(transformed == expected)
}

test("multiTransformDownWithContinuation doesn't stop generating alternatives of descendants " +
"when non-leaf is transformed but itself is not in the alternatives") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDownWithContinuation {
case Add(StringLiteral("a"), StringLiteral("b"), _) =>
Stream(Literal(11), Literal(12), Literal(21), Literal(22)) -> true
case StringLiteral("a") => Stream(Literal(1), Literal(2)) -> true
case StringLiteral("b") => Stream(Literal(10), Literal(20)) -> true
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
Stream(Literal(100), Literal(200)) -> true
}
val expected = for {
cd <- Seq(Literal(100), Literal(200))
ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++
(for {
b <- Seq(Literal(10), Literal(20))
a <- Seq(Literal(1), Literal(2))
} yield Add(a, b))
} yield Add(ab, cd)
assert(transformed == expected)
}

test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " +
"transformed and itself is in the alternatives") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
case StringLiteral("a") => Stream(Literal(1), Literal(2))
case StringLiteral("b") => Stream(Literal(10), Literal(20))
case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
}
val expected = for {
cd <- Seq(Literal(100), Literal(200))
ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++
(for {
b <- Seq(Literal(10), Literal(20))
a <- Seq(Literal(1), Literal(2))
} yield Add(a, b))
} yield Add(ab, cd)
assert(transformed == expected)
}

test("multiTransformDown can prune") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => Stream.empty
}
assert(transformed.isEmpty)

val transformed2 = e.multiTransformDown {
case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
}
assert(transformed2.isEmpty)
}
}
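
Editorial sketch, independent of Spark, of the behaviour the Scaladoc examples and the tests above describe. `Expr`, `Lit`, `Add` and this `multiTransformDown` are hypothetical miniatures; pruning conditions, rule ids, tags and the `autoContinue` flag are omitted. The leftmost child varies fastest, matching the ordering the tests expect. `Stream` is kept to mirror the diff, although it is deprecated in Scala 2.13 in favour of `LazyList`.

```scala
object MultiTransformSketch {
  // Hypothetical miniature expression tree; not Spark's Expression hierarchy.
  sealed trait Expr
  final case class Lit(v: String) extends Expr
  final case class Add(l: Expr, r: Expr) extends Expr

  def multiTransformDown(e: Expr)(
      rule: PartialFunction[Expr, Stream[Expr]]): Stream[Expr] = {
    // Alternatives for this node: the rule's result, or the node itself
    // when the rule does not apply. An empty result prunes the node.
    val here = rule.applyOrElse(e, (x: Expr) => Stream(x))
    here.flatMap {
      case Add(l, r) =>
        // Recurse into the children of each alternative (pre-order).
        for {
          nr <- multiTransformDown(r)(rule)
          nl <- multiTransformDown(l)(rule)
        } yield (Add(nl, nr): Expr)
      case leaf => Stream(leaf)
    }
  }

  val input: Expr = Add(Lit("a"), Lit("b"))

  // Example 1 of the Scaladoc: the rule for Add(a, b) omits the original node.
  val rule1: PartialFunction[Expr, Stream[Expr]] = {
    case Lit("a") => Stream(Lit("1"), Lit("2"))
    case Lit("b") => Stream(Lit("10"), Lit("20"))
    case Add(Lit("a"), Lit("b")) =>
      Stream(Lit("11"), Lit("12"), Lit("21"), Lit("22"))
  }

  // Example 2: the original Add(a, b) is listed, so children are expanded too.
  val rule2: PartialFunction[Expr, Stream[Expr]] = {
    case Lit("a") => Stream(Lit("1"), Lit("2"))
    case Lit("b") => Stream(Lit("10"), Lit("20"))
    case Add(Lit("a"), Lit("b")) =>
      Stream(Lit("11"), Lit("12"), Lit("21"), Lit("22"), input)
  }

  def main(args: Array[String]): Unit = {
    println(multiTransformDown(input)(rule1).toList)
    println(multiTransformDown(input)(rule2).toList)
    // The caller bounds how many alternatives are materialized.
    println(multiTransformDown(input)(rule2).take(2).toList)
  }
}
```

With `rule2` the result is the four literals followed by `Add(1, 10)`, `Add(2, 10)`, `Add(1, 20)`, `Add(2, 20)`, as in example 2 of the Scaladoc.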