Skip to content

Commit 067b788

Browse files
committed
[SPARK-16323] [SQL] Add IntegerDivide to avoid unnecessary cast
Before: ``` scala> spark.sql("select 6 div 3").explain(true) ... == Analyzed Logical Plan == CAST((6 / 3) AS BIGINT): bigint Project [cast((cast(6 as double) / cast(3 as double)) as bigint) AS CAST((6 / 3) AS BIGINT)#5L] +- OneRowRelation$ ... ``` After: ``` scala> spark.sql("select 6 div 3").explain(true) ... == Analyzed Logical Plan == (6 / 3): int Project [(6 / 3) AS (6 / 3)#11] +- OneRowRelation$ ... ```
1 parent 192d1f9 commit 067b788

7 files changed

Lines changed: 97 additions & 12 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ object FunctionRegistry {
232232
expression[Subtract]("-"),
233233
expression[Multiply]("*"),
234234
expression[Divide]("/"),
235+
expression[IntegerDivide]("div"),
235236
expression[Remainder]("%"),
236237

237238
// aggregate functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ package object dsl {
7171
def - (other: Expression): Expression = Subtract(expr, other)
7272
def * (other: Expression): Expression = Multiply(expr, other)
7373
def / (other: Expression): Expression = Divide(expr, other)
74+
def div (other: Expression): Expression = IntegerDivide(expr, other)
7475
def % (other: Expression): Expression = Remainder(expr, other)
7576
def & (other: Expression): Expression = BitwiseAnd(expr, other)
7677
def | (other: Expression): Expression = BitwiseOr(expr, other)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ case class Divide(left: Expression, right: Expression)
216216
override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
217217

218218
override def symbol: String = "/"
219-
override def decimalMethod: String = "$div"
220219
override def nullable: Boolean = true
221220

222221
private lazy val div: (Any, Any) => Any = dataType match {
@@ -284,6 +283,75 @@ case class Divide(left: Expression, right: Expression)
284283
}
285284
}
286285

286+
@ExpressionDescription(
287+
usage = "a _FUNC_ b - Divides a by b.",
288+
extended = "> SELECT 3 _FUNC_ 2;\n 1")
289+
case class IntegerDivide(left: Expression, right: Expression)
290+
extends BinaryArithmetic with NullIntolerant {
291+
292+
override def inputType: AbstractDataType = IntegralType
293+
294+
override def symbol: String = "/"
295+
override def decimalMethod: String = "$div"
296+
override def nullable: Boolean = true
297+
298+
private lazy val div: (Any, Any) => Any = dataType match {
299+
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]].quot
300+
}
301+
302+
override def eval(input: InternalRow): Any = {
303+
val input2 = right.eval(input)
304+
if (input2 == null || input2 == 0) {
305+
null
306+
} else {
307+
val input1 = left.eval(input)
308+
if (input1 == null) {
309+
null
310+
} else {
311+
div(input1, input2)
312+
}
313+
}
314+
}
315+
316+
/**
317+
* Special case handling due to division by 0 => null.
318+
*/
319+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
320+
val eval1 = left.genCode(ctx)
321+
val eval2 = right.genCode(ctx)
322+
val isZero = s"${eval2.value} == 0"
323+
val javaType = ctx.javaType(dataType)
324+
val divide = s"($javaType)(${eval1.value} $symbol ${eval2.value})"
325+
if (!left.nullable && !right.nullable) {
326+
ev.copy(code = s"""
327+
${eval2.code}
328+
boolean ${ev.isNull} = false;
329+
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
330+
if ($isZero) {
331+
${ev.isNull} = true;
332+
} else {
333+
${eval1.code}
334+
${ev.value} = $divide;
335+
}""")
336+
} else {
337+
ev.copy(code = s"""
338+
${eval2.code}
339+
boolean ${ev.isNull} = false;
340+
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
341+
if (${eval2.isNull} || $isZero) {
342+
${ev.isNull} = true;
343+
} else {
344+
${eval1.code}
345+
if (${eval1.isNull}) {
346+
${ev.isNull} = true;
347+
} else {
348+
${ev.value} = $divide;
349+
}
350+
}""")
351+
}
352+
}
353+
}
354+
287355
@ExpressionDescription(
288356
usage = "a _FUNC_ b - Returns the remainder when dividing a by b.")
289357
case class Remainder(left: Expression, right: Expression)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
957957
case SqlBaseParser.PERCENT =>
958958
Remainder(left, right)
959959
case SqlBaseParser.DIV =>
960-
Cast(Divide(left, right), LongType)
960+
IntegerDivide(left, right)
961961
case SqlBaseParser.PLUS =>
962962
Add(left, right)
963963
case SqlBaseParser.MINUS =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
140140

141141
// By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType.
142142
// TODO: in future release, we should add a IntegerDivide to support integral types.
143-
ignore("/ (Divide) for integral type") {
144-
checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
145-
checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
146-
checkEvaluation(Divide(Literal(1), Literal(2)), 0)
147-
checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
148-
checkEvaluation(Divide(positiveShortLit, negativeShortLit), 0.toShort)
149-
checkEvaluation(Divide(positiveIntLit, negativeIntLit), 0)
150-
checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L)
143+
test("/ (Divide) for integral type") {
144+
checkEvaluation(IntegerDivide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
145+
checkEvaluation(IntegerDivide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
146+
checkEvaluation(IntegerDivide(Literal(1), Literal(2)), 0)
147+
checkEvaluation(IntegerDivide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
148+
checkEvaluation(IntegerDivide(positiveShortLit, negativeShortLit), 0.toShort)
149+
checkEvaluation(IntegerDivide(positiveIntLit, negativeIntLit), 0)
150+
checkEvaluation(IntegerDivide(positiveLongLit, negativeLongLit), 0L)
151151
}
152152

153153
test("% (Remainder)") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ class ExpressionParserSuite extends PlanTest {
169169
// Simple operations
170170
assertEqual("a * b", 'a * 'b)
171171
assertEqual("a / b", 'a / 'b)
172-
assertEqual("a DIV b", ('a / 'b).cast(LongType))
172+
assertEqual("a DIV b", ('a div 'b))
173173
assertEqual("a % b", 'a % 'b)
174174
assertEqual("a + b", 'a + 'b)
175175
assertEqual("a - b", 'a - 'b)
@@ -180,7 +180,7 @@ class ExpressionParserSuite extends PlanTest {
180180
// Check precedences
181181
assertEqual(
182182
"a * t | b ^ c & d - e + f % g DIV h / i * k",
183-
'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k)))))
183+
'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g div 'h) / 'i * 'k)))))
184184
}
185185

186186
test("unary arithmetic expressions") {

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,21 @@ class Column(protected[sql] val expr: Expression) extends Logging {
731731
*/
732732
def / (other: Any): Column = withExpr { Divide(expr, lit(other).expr) }
733733

734+
/**
735+
* Integer Division this expression by another expression.
736+
* {{{
737+
* // Scala: The following divides a person's height by their weight.
738+
* people.select( people("height") div people("weight") )
739+
*
740+
* // Java:
741+
* people.select( people("height").div(people("weight")) );
742+
* }}}
743+
*
744+
* @group expr_ops
745+
* @since 2.1.0
746+
*/
747+
def div (other: Any): Column = withExpr { IntegerDivide(expr, lit(other).expr) }
748+
734749
/**
735750
* Division this expression by another expression.
736751
* {{{

0 commit comments

Comments
 (0)