Skip to content

Commit 0857f6d

Browse files
committed
[Fix] Resolve an issue introduced by apache/spark#42414, as identified in apache/spark#53038.
1 parent 29816f1 commit 0857f6d

1 file changed

Lines changed: 44 additions & 1 deletion

File tree

backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxBloomFilterAggregate.scala

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@ import org.apache.spark.sql.catalyst.expressions.Expression
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
2626
import org.apache.spark.sql.catalyst.trees.TernaryLike
2727
import org.apache.spark.sql.internal.SQLConf
28+
import org.apache.spark.sql.types._
2829
import org.apache.spark.sql.types.DataType
2930
import org.apache.spark.task.TaskResources
31+
import org.apache.spark.unsafe.types.UTF8String
3032
import org.apache.spark.util.sketch.BloomFilter
3133

34+
import java.io.Serializable
35+
3236
/**
3337
* Velox's bloom-filter implementation uses different algorithms internally comparing to vanilla
3438
* Spark so produces different intermediate aggregate data. Thus we use different filter function /
@@ -61,6 +65,15 @@ case class VeloxBloomFilterAggregate(
6165
.toLong
6266
)
6367

68+
// Mark as lazy so that `updater` is not evaluated during tree transformation.
69+
private lazy val updater: BloomFilterUpdater = child.dataType match {
70+
case LongType => LongUpdater
71+
case IntegerType => IntUpdater
72+
case ShortType => ShortUpdater
73+
case ByteType => ByteUpdater
74+
case _: StringType => BinaryUpdater
75+
}
76+
6477
override def first: Expression = child
6578

6679
override def second: Expression = estimatedNumItemsExpression
@@ -97,7 +110,7 @@ case class VeloxBloomFilterAggregate(
97110
if (value == null) {
98111
return buffer
99112
}
100-
buffer.putLong(value.asInstanceOf[Long])
113+
updater.update(buffer, value)
101114
buffer
102115
}
103116

@@ -128,3 +141,33 @@ case class VeloxBloomFilterAggregate(
128141
copy(inputAggBufferOffset = newOffset)
129142

130143
}
144+
145+
// Strategy interface for inserting a single non-null input value into a
// bloom filter, dispatched on the input's Catalyst data type.
// See https://github.com/apache/spark/pull/42414 for the upstream change
// this mirrors.
private trait BloomFilterUpdater {
  // Inserts `v` into `bf`; the result is whatever the underlying
  // BloomFilter `put*` call reports.
  def update(bf: BloomFilter, v: Any): Boolean
}
149+
150+
// Handles LongType inputs: the value is inserted as-is.
private object LongUpdater extends BloomFilterUpdater with Serializable {
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val longValue = v.asInstanceOf[Long]
    bf.putLong(longValue)
  }
}
154+
155+
// Handles IntegerType inputs: the value is widened to Long before insertion.
private object IntUpdater extends BloomFilterUpdater with Serializable {
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val widened = v.asInstanceOf[Int].toLong
    bf.putLong(widened)
  }
}
159+
160+
// Handles ShortType inputs: the value is widened to Long before insertion.
private object ShortUpdater extends BloomFilterUpdater with Serializable {
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val widened = v.asInstanceOf[Short].toLong
    bf.putLong(widened)
  }
}
164+
165+
// Handles ByteType inputs: the value is widened to Long before insertion.
private object ByteUpdater extends BloomFilterUpdater with Serializable {
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val widened = v.asInstanceOf[Byte].toLong
    bf.putLong(widened)
  }
}
169+
170+
// Handles StringType inputs: the string's raw bytes are inserted as binary.
private object BinaryUpdater extends BloomFilterUpdater with Serializable {
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val bytes = v.asInstanceOf[UTF8String].getBytes
    bf.putBinary(bytes)
  }
}

0 commit comments

Comments
 (0)