@@ -86,4 +86,13 @@ object BindReferences extends Logging {
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
}

+  /**
+   * A helper function to bind given expressions to an input schema.
+   */
+  def bindReferences[A <: Expression](
+      expressions: Seq[A],
+      input: AttributeSeq): Seq[A] = {
+    expressions.map(BindReferences.bindReference(_, input))
+  }
}
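Note: the new helper centralizes a pattern that every file below previously spelled out by hand. A minimal before/after sketch, assuming `exprs: Seq[Expression]` and `child: SparkPlan` are in scope:

    // Before: each call site mapped over the sequence itself.
    val boundOld = exprs.map(BindReferences.bindReference(_, child.output))
    // After: one call; child.output (a Seq[Attribute]) is implicitly
    // wrapped as an AttributeSeq.
    val boundNew = bindReferences(exprs, child.output)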
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp


@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
*/
class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(expressions, inputSchema))
+    this(bindReferences(expressions, inputSchema))

private[this] val buffer = new Array[Any](expressions.size)
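Note: with this overload the caller hands over unbound expressions plus a schema and lets the projection bind them. A rough usage sketch, assuming `exprs: Seq[Expression]`, `attrs: Seq[Attribute]`, and `inputRow: InternalRow` are in scope:

    // The two-argument constructor binds first, then chains to the primary
    // constructor; apply() then evaluates each bound expression into the buffer.
    val proj = new InterpretedMutableProjection(exprs, attrs)
    val projected: InternalRow = proj(inputRow)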

@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
*/
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(expressions, inputSchema))

override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
@@ -99,7 +100,7 @@ object MutableProjection
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
}
}

@@ -162,7 +163,7 @@ object UnsafeProjection
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
}
}

@@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
}
}
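Note: MutableProjection, UnsafeProjection, and SafeProjection now share the same shape for their two-argument `create`: bind against the input schema, then delegate to the single-argument factory. A sketch of the equivalence, assuming `exprs: Seq[Expression]` and `plan: SparkPlan` are in scope:

    // Both lines build the same projection; the overload saves the binding step.
    val viaOverload = UnsafeProjection.create(exprs, plan.output)
    val manual = UnsafeProjection.create(bindReferences(exprs, plan.output))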
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp

// MutableProjection is not accessible in Java
@@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)

def generate(
expressions: Seq[Expression],
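Note: the same one-line `bind` body recurs in GenerateMutableProjection, GenerateOrdering, GenerateSafeProjection, and GenerateUnsafeProjection below; `CodeGenerator.generate(expressions, inputSchema)` funnels through it before compiling. A minimal sketch, assuming `exprs: Seq[Expression]` and `attrs: Seq[Attribute]` are in scope:

    // generate(in, schema) canonicalizes, binds via the shared helper, then
    // code-generates the projection.
    val proj: MutableProjection = GenerateMutableProjection.generate(exprs, attrs)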
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)

/**
* Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
extends Ordering[InternalRow] with KryoSerializable {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))

@transient
private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
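Note: `LazilyGeneratedOrdering` stores the bound `SortOrder`s and compiles the comparator lazily (the `@transient` field is regenerated after deserialization). A construction sketch, assuming an attribute `a` and rows `r1`, `r2: InternalRow` are in scope:

    // Bind an ascending order on `a` against a one-column schema, then compare.
    val ord = new LazilyGeneratedOrdering(Seq(SortOrder(a, Ascending)), Seq(a))
    val cmp: Int = ord.compare(r1, r2)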
@@ -21,6 +21,7 @@ import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)

private def createCodeForStruct(
ctx: CodegenContext,
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._

@@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
in.map(ExpressionCanonicalizer.execute)

protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)

def generate(
expressions: Seq[Expression],
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types._


@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {

def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))

def compare(a: InternalRow, b: InternalRow): Int = {
var i = 0
@@ -85,13 +85,6 @@ package object expressions {
override def apply(row: InternalRow): InternalRow = row
}

-  /**
-   * A helper function to bind given expressions to an input schema.
-   */
-  def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
-    exprs.map(BindReferences.bindReference(_, inputSchema))
-  }
-
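Note: `toBoundExprs` goes away in favor of `BindReferences.bindReferences`. Call sites can keep passing a plain `Seq[Attribute]` because `AttributeSeq` (the helper documented just below) wraps attribute sequences implicitly, so the stronger parameter type costs callers nothing. Sketch, assuming `exprs` and `attrs: Seq[Attribute]` in scope:

    // Seq[Attribute] is viewed as an AttributeSeq implicitly; no wrapping by hand.
    val bound: Seq[Expression] = bindReferences(exprs, attrs)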
/**
* Helper functions for working with `Seq[Attribute]`.
*/
@@ -145,11 +145,12 @@ case class ExpandExec(
// Part 1: declare variables for each column
// If a column has the same value for all output rows, then we also generate its computation
// right after declaration. Otherwise its value is computed in the part 2.
+    lazy val attributeSeq: AttributeSeq = child.output
val outputColumns = output.indices.map { col =>
val firstExpr = projections.head(col)
if (sameOutput(col)) {
// This column is the same across all output rows. Just generate code for it here.
-        BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
+        BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
@@ -170,7 +171,7 @@
var updateCode = ""
for (col <- exprs.indices) {
if (!sameOutput(col)) {
-          val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+          val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
updateCode +=
s"""
|${ev.code}
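Note: the ExpandExec hunks add a small optimization on top of the renaming: `child.output` is wrapped into an `AttributeSeq` once, rather than once per bound expression across the two loops. The pattern, with illustrative names:

    // Hoist the wrapper so AttributeSeq construction happens once and every
    // bindReference call below reuses it.
    lazy val attrs: AttributeSeq = child.output
    val bound = projections.map(_.map(BindReferences.bindReference(_, attrs)))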
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
val expressionsLength = expressions.length
val functions = new Array[AggregateFunction](expressionsLength)
var i = 0
+    val inputAttributeSeq: AttributeSeq = inputAttributes
while (i < expressionsLength) {
val func = expressions(i).aggregateFunction
val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
+          BindReferences.bindReference(func, inputAttributeSeq)
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
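Note: same hoisting idea as in ExpandExec: `inputAttributes` is wrapped as an `AttributeSeq` once, outside the `while` loop, so binding each Partial or Complete aggregate function no longer rebuilds the wrapper on every iteration.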
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -199,15 +200,13 @@ case class HashAggregateExec(
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        functions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
-      }
+      val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
(resultVars, s"""
|$evaluateAggResults
|${evaluateVariables(resultVars)}
@@ -264,7 +263,7 @@
}
}
ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -456,16 +455,16 @@
val evaluateBufferVars = evaluateVariables(bufferVars)
// evaluate the aggregation result
ctx.currentVars = bufferVars
-      val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        declFunctions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
ctx.currentVars = keyVars ++ aggResults
val inputAttrs = groupingAttributes ++ aggregateAttributes
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
s"""
$evaluateKeyVars
$evaluateBufferVars
@@ -494,9 +493,9 @@

ctx.currentVars = keyVars ++ resultBufferVars
val inputAttrs = resultExpressions.map(_.toAttribute)
-    val resultVars = resultExpressions.map { e =>
-      BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-    }
+    val resultVars = bindReferences[Expression](
+      resultExpressions,
+      inputAttrs).map(_.genCode(ctx))
s"""
$evaluateKeyVars
$evaluateResultBufferVars
@@ -506,9 +505,9 @@
// generate result based on grouping key
ctx.INPUT_ROW = keyTerm
ctx.currentVars = null
-      val eval = resultExpressions.map{ e =>
-        BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
-      }
+      val eval = bindReferences[Expression](
+        resultExpressions,
+        groupingAttributes).map(_.genCode(ctx))
consume(ctx, eval)
}
ctx.addNewFunction(funcName,
@@ -730,9 +729,9 @@
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// create grouping key
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      ctx, bindReferences[Expression](groupingExpressions, child.output))
val fastRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      bindReferences[Expression](groupingExpressions, child.output))
val unsafeRowKeys = unsafeRowKeyCode.value
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -825,7 +824,7 @@

val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+      val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +848,7 @@
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
-        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+        val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
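Note: the explicit type argument in `bindReferences[Expression](resultExpressions, ...)` is deliberate. `resultExpressions` is a `Seq[NamedExpression]`, and letting `A` infer as `NamedExpression` would be unsound: binding rewrites attribute nodes into `BoundReference`, which is not a `NamedExpression`, so the unchecked cast inside `bindReference` could not deliver what the inferred type promises. A minimal sketch, with the same names assumed in scope:

    // Widen A to Expression: the bound tree may legally contain BoundReference.
    val bound: Seq[Expression] = bindReferences[Expression](resultExpressions, inputAttrs)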
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+    val exprs = bindReferences[Expression](projectList, child.output)
val resultVars = exprs.map(_.genCode(ctx))
// Evaluation of non-deterministic expressions can't be deferred.
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,8 @@ object FileFormatWriter extends Logging {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-      val orderingExpr = requiredOrdering
-        .map(SortOrder(_, Ascending))
-        .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+      val orderingExpr = bindReferences(
+        requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
SortExec(
orderingExpr,
global = false,
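Note: per the SPARK-21165 comment above, the ordering is deliberately bound against `outputSpec.outputColumns` ahead of time, pinning attribute ids before the optimizer can remove aliases in the physical plan; the refactor only reroutes that existing behavior through the shared helper.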
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
@@ -63,9 +64,8 @@ trait HashJoin {
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
-    val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
-    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
-      .map(BindReferences.bindReference(_, right.output))
+    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
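Note: each key sequence is first normalized by `HashJoin.rewriteKeyExpr` and then bound against the side it will actually be evaluated on (`left.output` or `right.output`); `buildSide` merely decides, after binding, which bound sequence serves as the build keys and which as the streamed keys.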