Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,21 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = f(arg1.asInstanceOf[BaseType])
val newChild2 = f(arg2.asInstanceOf[BaseType])
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can call arg1.asInstanceOf[BaseType] here, to avoid this change

}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2
}

if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
(newChild1.asInstanceOf[BaseType], newChild2.asInstanceOf[BaseType])
} else {
tuple
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ case class ExpressionInMap(map: Map[String, Expression]) extends Expression with
override lazy val resolved = true
}

case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
notsons: Seq[(Expression, Expression)]) extends Expression with Unevaluable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: nonSons

override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2))
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}

case class JsonTestTreeNode(arg: Any) extends LeafNode {
override def output: Seq[Attribute] = Seq.empty[Attribute]
}
Expand Down Expand Up @@ -146,6 +154,23 @@ class TreeNodeSuite extends SparkFunSuite {
assert(actual === Dummy(None))
}

test("mapChildren should only works on children") {
val children = Seq((Literal(1), Literal(2)))
val notChildren = Seq((Literal(3), Literal(4)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: nonChildren

val before = SeqTupleExpression(children, notChildren)
val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) }
val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), notChildren)

var actual = before transformDown toZero
assert(actual === expect)

actual = before transformUp toZero
assert(actual === expect)

actual = before transform toZero
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think testing transform is good enough

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or we can testing mapChildren directly

assert(actual === expect)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe better to use .equals? Although it will call Object's .equals which is ==.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I understand it wrongly. === compared to == could provide more information about the error. You can see the follow examples:

scala> assert(1 == 2)
java.lang.AssertionError: assertion failed
  at scala.Predef$.assert(Predef.scala:156)
  ... 32 elided

scala> assert(1 === 2)
<console>:12: error: value === is not a member of Int
       assert(1 === 2)

And also you can check it in https://stackoverflow.com/questions/10489548/what-is-the-triple-equals-operator-in-scala-koans

}

test("preserves origin") {
CurrentOrigin.setPosition(1, 1)
val add = Add(Literal(1), Literal(1))
Expand Down