@@ -25,10 +25,14 @@ import org.apache.spark.sql.catalyst.expressions.Expression
2525import org .apache .spark .sql .catalyst .expressions .aggregate .TypedImperativeAggregate
2626import org .apache .spark .sql .catalyst .trees .TernaryLike
2727import org .apache .spark .sql .internal .SQLConf
28+ import org .apache .spark .sql .types ._
2829import org .apache .spark .sql .types .DataType
2930import org .apache .spark .task .TaskResources
31+ import org .apache .spark .unsafe .types .UTF8String
3032import org .apache .spark .util .sketch .BloomFilter
3133
34+ import java .io .Serializable
35+
3236/**
3337 * Velox's bloom-filter implementation uses different algorithms internally compared to vanilla
3438 * Spark, so it produces different intermediate aggregate data. Thus we use different filter function /
@@ -61,6 +65,15 @@ case class VeloxBloomFilterAggregate(
6165 .toLong
6266 )
6367
// Mark as lazy so that `updater` is not evaluated during tree transformation.
//
// Selects the updater that matches the data type of `child`. Integral inputs are mapped to the
// updaters that call `putLong`; string inputs are mapped to the updater that calls `putBinary`.
private lazy val updater: BloomFilterUpdater = child.dataType match {
  case LongType => LongUpdater
  case IntegerType => IntUpdater
  case ShortType => ShortUpdater
  case ByteType => ByteUpdater
  case _: StringType => BinaryUpdater
  case other =>
    // Fail with an explicit message instead of an opaque scala.MatchError when an input type
    // without a matching updater reaches this point.
    throw new UnsupportedOperationException(
      s"VeloxBloomFilterAggregate does not support input of type $other")
}
76+
6477 override def first : Expression = child
6578
6679 override def second : Expression = estimatedNumItemsExpression
@@ -97,7 +110,7 @@ case class VeloxBloomFilterAggregate(
97110 if (value == null ) {
98111 return buffer
99112 }
100- buffer.putLong( value. asInstanceOf [ Long ] )
113+ updater.update(buffer, value)
101114 buffer
102115 }
103116
@@ -128,3 +141,33 @@ case class VeloxBloomFilterAggregate(
128141 copy(inputAggBufferOffset = newOffset)
129142
130143}
144+
// see https://github.com/apache/spark/pull/42414
/**
 * Strategy for inserting a single input value into a [[BloomFilter]].
 *
 * Each implementation casts the untyped value to the concrete type it handles before adding it
 * to the filter.
 */
private trait BloomFilterUpdater {

  /**
   * Adds `v` to `bf`.
   *
   * @param bf the bloom filter to update
   * @param v the value to insert; its expected runtime type is defined by the implementation
   * @return the result reported by the underlying `BloomFilter` put operation
   */
  def update(bf: BloomFilter, v: Any): Boolean
}
149+
/** Updater for values whose runtime type is `Long`. */
private object LongUpdater extends BloomFilterUpdater with Serializable {

  /** Casts `v` to `Long`, adds it via `putLong`, and returns that call's result. */
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val item: Long = v.asInstanceOf[Long]
    bf.putLong(item)
  }
}
154+
/** Updater for values whose runtime type is `Int`. */
private object IntUpdater extends BloomFilterUpdater with Serializable {

  /** Casts `v` to `Int`, widens it to `Long`, adds it via `putLong`, and returns the result. */
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val item: Long = v.asInstanceOf[Int].toLong
    bf.putLong(item)
  }
}
159+
/** Updater for values whose runtime type is `Short`. */
private object ShortUpdater extends BloomFilterUpdater with Serializable {

  /** Casts `v` to `Short`, widens it to `Long`, adds it via `putLong`, and returns the result. */
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val item: Long = v.asInstanceOf[Short].toLong
    bf.putLong(item)
  }
}
164+
/** Updater for values whose runtime type is `Byte`. */
private object ByteUpdater extends BloomFilterUpdater with Serializable {

  /** Casts `v` to `Byte`, widens it to `Long`, adds it via `putLong`, and returns the result. */
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val item: Long = v.asInstanceOf[Byte].toLong
    bf.putLong(item)
  }
}
169+
/** Updater for values whose runtime type is `UTF8String`. */
private object BinaryUpdater extends BloomFilterUpdater with Serializable {

  /** Casts `v` to `UTF8String`, adds its bytes via `putBinary`, and returns the result. */
  override def update(bf: BloomFilter, v: Any): Boolean = {
    val bytes = v.asInstanceOf[UTF8String].getBytes
    bf.putBinary(bytes)
  }
}
0 commit comments