@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
  */
 class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+    this(toBoundExprs(expressions, inputSchema))

   override def initialize(partitionIndex: Int): Unit = {
     expressions.foreach(_.foreach {
@@ -35,7 +35,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
     in.map(ExpressionCanonicalizer.execute)

   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    toBoundExprs(in, inputSchema)

   def generate(
       expressions: Seq[Expression],
@@ -46,7 +46,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

   protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    toBoundExprs(in, inputSchema)

   /**
    * Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +188,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
   extends Ordering[InternalRow] with KryoSerializable {

   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(toBoundExprs(ordering, inputSchema))

   @transient
   private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
@@ -41,7 +41,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     in.map(ExpressionCanonicalizer.execute)

   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    toBoundExprs(in, inputSchema)

   private def createCodeForStruct(
       ctx: CodegenContext,
@@ -317,7 +317,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     in.map(ExpressionCanonicalizer.execute)

   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    toBoundExprs(in, inputSchema)

   def generate(
       expressions: Seq[Expression],
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
 class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {

   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(toBoundExprs(ordering, inputSchema))

   def compare(a: InternalRow, b: InternalRow): Int = {
     var i = 0
@@ -88,7 +88,9 @@ package object expressions {
   /**
    * A helper function to bind given expressions to an input schema.
    */
-  def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
+  def toBoundExprs[A <: Expression](
+      exprs: Seq[A],
+      inputSchema: AttributeSeq): Seq[A] = {
     exprs.map(BindReferences.bindReference(_, inputSchema))
   }

Reviewer comment (Contributor), on the new toBoundExprs[A <: Expression]( line:
I would like to minimize the chance that future changes suffer from the same issue. In order to do that, we should provide this API in a logical place; it does not make a whole lot of sense to me that I need to look in package.scala to find a more performant version of BindReferences.bindReference(..) for a seq. Can we move this function to BindReferences and name it bindReferences?

Reviewer comment (Contributor), on the exprs: Seq[A], line:
nit: indent
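In line with the review comment above, the Seq-binding helper could sit next to bindReference itself instead of in package.scala. The sketch below shows what that suggestion might look like; the object placement and the name bindReferences come straight from the comment, while the exact signature and the elided existing member are assumptions rather than part of this diff.

object BindReferences {

  // Existing single-expression binder, elided here:
  // def bindReference[A <: Expression](expression: A, input: AttributeSeq, ...): A = ...

  /**
   * A helper function to bind given expressions to an input schema.
   * (Assumed follow-up to the review comment; mirrors toBoundExprs above.)
   */
  def bindReferences[A <: Expression](
      expressions: Seq[A],
      input: AttributeSeq): Seq[A] = {
    expressions.map(BindReferences.bindReference(_, input))
  }
}

Call sites would then read, for example, BindReferences.bindReferences(updateExpr, inputAttrs), and the Seq[Attribute]-to-AttributeSeq conversion would still happen once per call rather than once per bound expression, which seems to be the performance point the comment refers to.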
@@ -264,7 +264,7 @@ case class HashAggregateExec(
       }
     }
     ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val boundUpdateExpr = toBoundExprs(updateExpr, inputAttrs)
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
     val effectiveCodes = subExprs.codes.mkString("\n")
     val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -825,7 +825,7 @@ case class HashAggregateExec(

     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+      val boundUpdateExpr = toBoundExprs(updateExpr, inputAttr)
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
       val effectiveCodes = subExprs.codes.mkString("\n")
       val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +849,7 @@
     if (isFastHashMapEnabled) {
       if (isVectorizedHashMapEnabled) {
         ctx.INPUT_ROW = fastRowBuffer
-        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+        val boundUpdateExpr = toBoundExprs(updateExpr, inputAttr)
         val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
         val effectiveCodes = subExprs.codes.mkString("\n")
         val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -63,9 +63,8 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
-    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
-      .map(BindReferences.bindReference(_, right.output))
+    val lkeys = toBoundExprs(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+    val rkeys = toBoundExprs(HashJoin.rewriteKeyExpr(rightKeys), right.output)
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
@@ -393,7 +393,7 @@ case class SortMergeJoinExec(
       input: Seq[Attribute]): Seq[ExprCode] = {
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    keys.map(BindReferences.bindReference(_, input).genCode(ctx))
+    toBoundExprs(keys, input).map(_.genCode(ctx))
   }

   private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {