Skip to content

Commit b4adb0f

Browse files
committed
Merge pull request #14 from marmbrus/castingAndTypes
Make casting semantics more like Hive's
2 parents b21f803 + b2a1ec5 commit b4adb0f

File tree

9 files changed

+242
-94
lines changed

9 files changed

+242
-94
lines changed

src/main/scala/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
3636
Batch("Aggregation", Once,
3737
GlobalAggregates),
3838
Batch("Type Coersion", fixedPoint,
39+
StringToIntegralCasts,
40+
BooleanCasts,
3941
PromoteNumericTypes,
4042
PromoteStrings,
4143
ConvertNaNs,

src/main/scala/catalyst/analysis/typeCoercion.scala

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,30 @@ object BooleanComparisons extends Rule[LogicalPlan] {
151151
}
152152
}
153153

154+
/**
155+
* Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since
156+
* the JVM does not consider Booleans to be numeric types.
157+
*/
158+
object BooleanCasts extends Rule[LogicalPlan] {
159+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
160+
case Cast(e, BooleanType) => Not(Equals(e, Literal(0)))
161+
case Cast(e, dataType) if e.dataType == BooleanType =>
162+
Cast(If(e, Literal(1), Literal(0)), dataType)
163+
}
164+
}
165+
166+
/**
167+
* When encountering a cast from a string representing a valid fractional number to an integral type
168+
* the jvm will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the
169+
* truncated version of this number.
170+
*/
171+
object StringToIntegralCasts extends Rule[LogicalPlan] {
172+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
173+
case Cast(e @ StringType(), t: IntegralType) =>
174+
Cast(Cast(e, DecimalType), t)
175+
}
176+
}
177+
154178
/**
155179
* This ensure that the types for various functions are as expected. Most of these rules are
156180
* actually Hive specific.
@@ -162,9 +186,9 @@ object FunctionArgumentConversion extends Rule[LogicalPlan] {
162186
case e if !e.childrenResolved => e
163187

164188
// Promote SUM to largest types to prevent overflows.
165-
// TODO: This is enough to make most of the tests pass, but we really need a full set of our own
166-
// to really ensure compatibility.
167-
case Sum(e) if e.dataType == IntegerType => Sum(Cast(e, LongType))
189+
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
190+
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
191+
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))
168192

169193
}
170194
}

src/main/scala/catalyst/execution/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ abstract class HiveUdf extends Expression with ImplementedUdf with Logging {
8787
case l: LongWritable => l.get
8888
case d: DoubleWritable => d.get()
8989
case d: org.apache.hadoop.hive.serde2.io.DoubleWritable => d.get
90+
case s: org.apache.hadoop.hive.serde2.io.ShortWritable => s.get
9091
case b: BooleanWritable => b.get()
92+
case b: org.apache.hadoop.hive.serde2.io.ByteWritable => b.get
9193
case list: java.util.List[_] => list.map(unwrap)
9294
case p: java.lang.Short => p
9395
case p: java.lang.Long => p

src/main/scala/catalyst/execution/SharkInstance.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ abstract class SharkInstance extends Logging {
135135
s"""== Logical Plan ==
136136
|${stringOrError(analyzed)}
137137
|== Physical Plan ==
138-
|${stringOrError(sharkPlan)}
138+
|${stringOrError(executedPlan)}
139139
""".stripMargin.trim
140140
}
141141

src/main/scala/catalyst/expressions/Evaluate.scala

Lines changed: 40 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -25,109 +25,73 @@ object Evaluate extends Logging {
2525
null
2626
else
2727
e.dataType match {
28-
case IntegerType =>
29-
f.asInstanceOf[(Numeric[Int], Int) => Int](
30-
implicitly[Numeric[Int]], eval(e).asInstanceOf[Int])
31-
case DoubleType =>
32-
f.asInstanceOf[(Numeric[Double], Double) => Double](
33-
implicitly[Numeric[Double]], eval(e).asInstanceOf[Double])
34-
case LongType =>
35-
f.asInstanceOf[(Numeric[Long], Long) => Long](
36-
implicitly[Numeric[Long]], eval(e).asInstanceOf[Long])
37-
case FloatType =>
38-
f.asInstanceOf[(Numeric[Float], Float) => Float](
39-
implicitly[Numeric[Float]], eval(e).asInstanceOf[Float])
40-
case ByteType =>
41-
f.asInstanceOf[(Numeric[Byte], Byte) => Byte](
42-
implicitly[Numeric[Byte]], eval(e).asInstanceOf[Byte])
43-
case ShortType =>
44-
f.asInstanceOf[(Numeric[Short], Short) => Short](
45-
implicitly[Numeric[Short]], eval(e).asInstanceOf[Short])
28+
case n: NumericType =>
29+
val castedFunction = f.asInstanceOf[(Numeric[n.JvmType], n.JvmType) => n.JvmType]
30+
castedFunction(n.numeric, eval(e).asInstanceOf[n.JvmType])
4631
case other => sys.error(s"Type $other does not support numeric operations")
4732
}
4833
}
4934

5035
@inline
5136
def n2(e1: Expression, e2: Expression, f: ((Numeric[Any], Any, Any) => Any)): Any = {
5237
if (e1.dataType != e2.dataType)
53-
throw new OptimizationException(e, s"Data types do not match ${e1.dataType} != ${e2.dataType}")
38+
throw new OptimizationException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}")
5439

5540
val evalE1 = eval(e1)
5641
val evalE2 = eval(e2)
5742
if (evalE1 == null || evalE2 == null)
5843
null
5944
else
6045
e1.dataType match {
61-
case IntegerType =>
62-
f.asInstanceOf[(Numeric[Int], Int, Int) => Int](
63-
implicitly[Numeric[Int]], evalE1.asInstanceOf[Int], evalE2.asInstanceOf[Int])
64-
case DoubleType =>
65-
f.asInstanceOf[(Numeric[Double], Double, Double) => Double](
66-
implicitly[Numeric[Double]], evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double])
67-
case LongType =>
68-
f.asInstanceOf[(Numeric[Long], Long, Long) => Long](
69-
implicitly[Numeric[Long]], evalE1.asInstanceOf[Long], evalE2.asInstanceOf[Long])
70-
case FloatType =>
71-
f.asInstanceOf[(Numeric[Float], Float, Float) => Float](
72-
implicitly[Numeric[Float]], evalE1.asInstanceOf[Float], evalE2.asInstanceOf[Float])
73-
case ByteType =>
74-
f.asInstanceOf[(Numeric[Byte], Byte, Byte) => Byte](
75-
implicitly[Numeric[Byte]], evalE1.asInstanceOf[Byte], evalE2.asInstanceOf[Byte])
76-
case ShortType =>
77-
f.asInstanceOf[(Numeric[Short], Short, Short) => Short](
78-
implicitly[Numeric[Short]], evalE1.asInstanceOf[Short], evalE2.asInstanceOf[Short])
46+
case n: NumericType =>
47+
f.asInstanceOf[(Numeric[n.JvmType], n.JvmType, n.JvmType) => Int](
48+
n.numeric, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType])
7949
case other => sys.error(s"Type $other does not support numeric operations")
8050
}
8151
}
8252

8353
@inline
8454
def f2(e1: Expression, e2: Expression, f: ((Fractional[Any], Any, Any) => Any)): Any = {
8555
if (e1.dataType != e2.dataType)
86-
throw new OptimizationException(e, s"Data types do not match ${e1.dataType} != ${e2.dataType}")
56+
throw new OptimizationException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}")
8757

8858
val evalE1 = eval(e1)
8959
val evalE2 = eval(e2)
9060
if (evalE1 == null || evalE2 == null)
9161
null
9262
else
9363
e1.dataType match {
94-
case DoubleType =>
95-
f.asInstanceOf[(Fractional[Double], Double, Double) => Double](
96-
implicitly[Fractional[Double]], evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double])
97-
case FloatType =>
98-
f.asInstanceOf[(Fractional[Float], Float, Float) => Float](
99-
implicitly[Fractional[Float]], evalE1.asInstanceOf[Float], evalE2.asInstanceOf[Float])
64+
case f: FractionalType =>
65+
f.asInstanceOf[(Fractional[f.JvmType], f.JvmType, f.JvmType) => f.JvmType](
66+
f.fractional, evalE1.asInstanceOf[f.JvmType], evalE2.asInstanceOf[f.JvmType])
10067
case other => sys.error(s"Type $other does not support fractional operations")
10168
}
10269
}
10370

10471
@inline
10572
def i2(e1: Expression, e2: Expression, f: ((Integral[Any], Any, Any) => Any)): Any = {
106-
if (e1.dataType != e2.dataType) throw new OptimizationException(e, s"Data types do not match ${e1.dataType} != ${e2.dataType}")
73+
if (e1.dataType != e2.dataType)
74+
throw new OptimizationException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}")
10775
val evalE1 = eval(e1)
10876
val evalE2 = eval(e2)
10977
if (evalE1 == null || evalE2 == null)
11078
null
11179
else
11280
e1.dataType match {
113-
case IntegerType =>
114-
f.asInstanceOf[(Integral[Int], Int, Int) => Int](
115-
implicitly[Integral[Int]], evalE1.asInstanceOf[Int], evalE2.asInstanceOf[Int])
116-
case LongType =>
117-
f.asInstanceOf[(Integral[Long], Long, Long) => Long](
118-
implicitly[Integral[Long]], evalE1.asInstanceOf[Long], evalE2.asInstanceOf[Long])
119-
case ByteType =>
120-
f.asInstanceOf[(Integral[Byte], Byte, Byte) => Byte](
121-
implicitly[Integral[Byte]], evalE1.asInstanceOf[Byte], evalE2.asInstanceOf[Byte])
122-
case ShortType =>
123-
f.asInstanceOf[(Integral[Short], Short, Short) => Short](
124-
implicitly[Integral[Short]], evalE1.asInstanceOf[Short], evalE2.asInstanceOf[Short])
81+
case i: IntegralType =>
82+
f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType](
83+
i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
12584
case other => sys.error(s"Type $other does not support numeric operations")
12685
}
12786
}
12887

129-
@inline def castOrNull[A](f: => A) =
130-
try f catch { case _: java.lang.NumberFormatException => null }
88+
@inline def castOrNull[A](e: Expression, f: String => A) =
89+
try {
90+
eval(e) match {
91+
case null => null
92+
case s: String => f(s)
93+
}
94+
} catch { case _: java.lang.NumberFormatException => null }
13195

13296
val result = e match {
13397
case Literal(v, _) => v
@@ -142,13 +106,16 @@ object Evaluate extends Logging {
142106
case Add(l, r) => n2(l,r, _.plus(_, _))
143107
case Subtract(l, r) => n2(l,r, _.minus(_, _))
144108
case Multiply(l, r) => n2(l,r, _.times(_, _))
145-
// Divide & remainder implementation are different for fractional and integral dataTypes.
146-
case Divide(l, r) if (l.dataType == DoubleType || l.dataType == FloatType) => f2(l,r, _.div(_, _))
147-
case Divide(l, r) => i2(l,r, _.quot(_, _))
109+
// Divide implementation are different for fractional and integral dataTypes.
110+
case Divide(l @ FractionalType(), r) => f2(l,r, _.div(_, _))
111+
case Divide(l @ IntegralType(), r) => i2(l,r, _.quot(_, _))
148112
// Remainder is only allowed on Integral types.
149113
case Remainder(l, r) => i2(l,r, _.rem(_, _))
150114
case UnaryMinus(child) => n1(child, _.negate(_))
151115

116+
/* Control Flow */
117+
case If(e, t, f) => if (eval(e).asInstanceOf[Boolean]) eval(t) else eval(f)
118+
152119
/* Comparisons */
153120
case Equals(l, r) =>
154121
val left = eval(l)
@@ -197,16 +164,14 @@ object Evaluate extends Logging {
197164
}
198165

199166
// String => Numeric Types
200-
case Cast(e, IntegerType) if e.dataType == StringType =>
201-
eval(e) match {
202-
case null => null
203-
case s: String => castOrNull(s.toInt)
204-
}
205-
case Cast(e, DoubleType) if e.dataType == StringType =>
206-
eval(e) match {
207-
case null => null
208-
case s: String => castOrNull(s.toDouble)
209-
}
167+
case Cast(e @ StringType(), IntegerType) => castOrNull(e, _.toInt)
168+
case Cast(e @ StringType(), DoubleType) => castOrNull(e, _.toDouble)
169+
case Cast(e @ StringType(), FloatType) => castOrNull(e, _.toFloat)
170+
case Cast(e @ StringType(), LongType) => castOrNull(e, _.toLong)
171+
case Cast(e @ StringType(), ShortType) => castOrNull(e, _.toShort)
172+
case Cast(e @ StringType(), ByteType) => castOrNull(e, _.toByte)
173+
case Cast(e @ StringType(), DecimalType) => castOrNull(e, BigDecimal(_))
174+
210175
// Boolean conversions
211176
case Cast(e, ByteType) if e.dataType == BooleanType =>
212177
eval(e) match {
@@ -263,6 +228,9 @@ object Evaluate extends Logging {
263228
case implementedFunction: ImplementedUdf =>
264229
implementedFunction.evaluate(implementedFunction.children.map(eval))
265230

231+
case a: Attribute =>
232+
throw new OptimizationException(a,
233+
"Unable to evaluate unbound reference without access to the input schema.")
266234
case other => throw new OptimizationException(other, "evaluation not implemented")
267235
}
268236

src/main/scala/catalyst/expressions/predicates.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package catalyst
22
package expressions
33

44
import types._
5+
import catalyst.analysis.UnresolvedException
56

67
trait Predicate extends Expression {
78
self: Product =>
@@ -74,3 +75,19 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
7475
override def foldable = child.foldable
7576
def nullable = false
7677
}
78+
79+
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
80+
extends Expression {
81+
82+
def children = predicate :: trueValue :: falseValue :: Nil
83+
def nullable = trueValue.nullable || falseValue.nullable
84+
def references = children.flatMap(_.references).toSet
85+
override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
86+
def dataType = {
87+
if (!resolved) {
88+
throw new UnresolvedException(
89+
this, s"Invalid types: ${trueValue.dataType}, ${falseValue.dataType}")
90+
}
91+
trueValue.dataType
92+
}
93+
}

src/main/scala/catalyst/optimizer/Optimizer.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ object Optimize extends RuleExecutor[LogicalPlan] {
1212
EliminateSubqueries) ::
1313
Batch("ConstantFolding", Once,
1414
ConstantFolding,
15-
BooleanSimplification
16-
) :: Nil
15+
BooleanSimplification,
16+
SimplifyCasts) :: Nil
1717
}
1818

1919
/**
@@ -68,4 +68,13 @@ object BooleanSimplification extends Rule[LogicalPlan] {
6868
}
6969
}
7070
}
71+
}
72+
73+
/**
74+
* Removes casts that are unnecessary because the input is already the correct type.
75+
*/
76+
object SimplifyCasts extends Rule[LogicalPlan] {
77+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
78+
case Cast(e, dataType) if e.dataType == dataType => e
79+
}
7180
}

0 commit comments

Comments
 (0)