@@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
325325
326326 override def genCode (ctx : CodegenContext , ev : ExprCode ): String = {
327327 ev.isNull = " false"
328- val childrenHash = children.zipWithIndex.map {
329- case (child, dt) =>
330- val childGen = child.gen(ctx)
331- val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
332- s """
333- ${childGen.code}
334- if (! ${childGen.isNull}) {
335- ${childHash.code}
336- ${ev.value} = ${childHash.value};
337- }
338- """
328+ val childrenHash = children.map { child =>
329+ val childGen = child.gen(ctx)
330+ childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
331+ computeHash(childGen.value, child.dataType, ev.value, ctx)
332+ }
339333 }.mkString(" \n " )
334+
340335 s """
341336 int ${ev.value} = $seed;
342337 $childrenHash
343338 """
344339 }
345340
341+ private def generateNullCheck (nullable : Boolean , isNull : String )(execution : String ): String = {
342+ if (nullable) {
343+ s """
344+ if (! $isNull) {
345+ $execution
346+ }
347+ """
348+ } else {
349+ " \n " + execution
350+ }
351+ }
352+
353+ private def nullSafeElementHash (
354+ input : String ,
355+ index : String ,
356+ nullable : Boolean ,
357+ elementType : DataType ,
358+ result : String ,
359+ ctx : CodegenContext ): String = {
360+ val element = ctx.freshName(" element" )
361+
362+ generateNullCheck(nullable, s " $input.isNullAt( $index) " ) {
363+ s """
364+ final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
365+ ${computeHash(element, elementType, result, ctx)}
366+ """
367+ }
368+ }
369+
346370 private def computeHash (
347371 input : String ,
348372 dataType : DataType ,
349- seed : String ,
350- ctx : CodegenContext ): ExprCode = {
373+ result : String ,
374+ ctx : CodegenContext ): String = {
351375 val hasher = classOf [Murmur3_x86_32 ].getName
352- def hashInt (i : String ): ExprCode = inlineValue(s " $hasher.hashInt( $i, $seed) " )
353- def hashLong (l : String ): ExprCode = inlineValue(s " $hasher.hashLong( $l, $seed) " )
354- def inlineValue (v : String ): ExprCode = ExprCode (code = " " , isNull = " false" , value = v)
376+
377+ def hashInt (i : String ): String = s " $result = $hasher.hashInt( $i, $result); "
378+ def hashLong (l : String ): String = s " $result = $hasher.hashLong( $l, $result); "
379+ def hashBytes (b : String ): String =
380+ s " $result = $hasher.hashUnsafeBytes( $b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result); "
355381
356382 dataType match {
357- case NullType => inlineValue(seed)
383+ case NullType => " "
358384 case BooleanType => hashInt(s " $input ? 1 : 0 " )
359385 case ByteType | ShortType | IntegerType | DateType => hashInt(input)
360386 case LongType | TimestampType => hashLong(input)
@@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
365391 hashLong(s " $input.toUnscaledLong() " )
366392 } else {
367393 val bytes = ctx.freshName(" bytes" )
368- val code = s " byte[] $bytes = $input .toJavaBigDecimal().unscaledValue().toByteArray(); "
369- val offset = " Platform.BYTE_ARRAY_OFFSET "
370- val result = s " $hasher .hashUnsafeBytes( $ bytes, $offset , $bytes .length, $seed ) "
371- ExprCode (code, " false " , result)
394+ s "" "
395+ final byte[] $bytes = $input .toJavaBigDecimal().unscaledValue().toByteArray();
396+ ${hashBytes( bytes)}
397+ """
372398 }
373399 case CalendarIntervalType =>
374- val microsecondsHash = s " $hasher.hashLong( $input.microseconds, $seed) "
375- val monthsHash = s " $hasher.hashInt( $input.months, $microsecondsHash) "
376- inlineValue(monthsHash)
377- case BinaryType =>
378- val offset = " Platform.BYTE_ARRAY_OFFSET"
379- inlineValue(s " $hasher.hashUnsafeBytes( $input, $offset, $input.length, $seed) " )
400+ val microsecondsHash = s " $hasher.hashLong( $input.microseconds, $result) "
401+ s " $result = $hasher.hashInt( $input.months, $microsecondsHash); "
402+ case BinaryType => hashBytes(input)
380403 case StringType =>
381404 val baseObject = s " $input.getBaseObject() "
382405 val baseOffset = s " $input.getBaseOffset() "
383406 val numBytes = s " $input.numBytes() "
384- inlineValue( s " $hasher.hashUnsafeBytes( $baseObject, $baseOffset, $numBytes, $seed ) " )
407+ s " $result = $ hasher.hashUnsafeBytes( $baseObject, $baseOffset, $numBytes, $result ); "
385408
386- case ArrayType (et, _) =>
387- val result = ctx.freshName(" result" )
409+ case ArrayType (et, containsNull) =>
388410 val index = ctx.freshName(" index" )
389- val element = ctx.freshName(" element" )
390- val elementHash = computeHash(element, et, result, ctx)
391- val code =
392- s """
393- int $result = $seed;
394- for (int $index = 0; $index < $input.numElements(); $index++) {
395- if (! $input.isNullAt( $index)) {
396- final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
397- ${elementHash.code}
398- $result = ${elementHash.value};
399- }
400- }
401- """
402- ExprCode (code, " false" , result)
411+ s """
412+ for (int $index = 0; $index < $input.numElements(); $index++) {
413+ ${nullSafeElementHash(input, index, containsNull, et, result, ctx)}
414+ }
415+ """
403416
404- case MapType (kt, vt, _) =>
405- val result = ctx.freshName(" result" )
417+ case MapType (kt, vt, valueContainsNull) =>
406418 val index = ctx.freshName(" index" )
407419 val keys = ctx.freshName(" keys" )
408420 val values = ctx.freshName(" values" )
409- val key = ctx.freshName(" key" )
410- val value = ctx.freshName(" value" )
411- val keyHash = computeHash(key, kt, result, ctx)
412- val valueHash = computeHash(value, vt, result, ctx)
413- val code =
414- s """
415- int $result = $seed;
416- final ArrayData $keys = $input.keyArray();
417- final ArrayData $values = $input.valueArray();
418- for (int $index = 0; $index < $input.numElements(); $index++) {
419- final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
420- ${keyHash.code}
421- $result = ${keyHash.value};
422- if (! $values.isNullAt( $index)) {
423- final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
424- ${valueHash.code}
425- $result = ${valueHash.value};
426- }
427- }
428- """
429- ExprCode (code, " false" , result)
421+ s """
422+ final ArrayData $keys = $input.keyArray();
423+ final ArrayData $values = $input.valueArray();
424+ for (int $index = 0; $index < $input.numElements(); $index++) {
425+ ${nullSafeElementHash(keys, index, false , kt, result, ctx)}
426+ ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)}
427+ }
428+ """
430429
431430 case StructType (fields) =>
432- val result = ctx.freshName(" result" )
433- val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
434- case (dt, index) =>
435- val field = ctx.freshName(" field" )
436- val fieldHash = computeHash(field, dt, result, ctx)
437- s """
438- if (! $input.isNullAt( $index)) {
439- final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
440- ${fieldHash.code}
441- $result = ${fieldHash.value};
442- }
443- """
431+ fields.zipWithIndex.map { case (field, index) =>
432+ nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
444433 }.mkString(" \n " )
445- val code =
446- s """
447- int $result = $seed;
448- $fieldsHash
449- """
450- ExprCode (code, " false" , result)
451434
452- case udt : UserDefinedType [_] => computeHash(input, udt.sqlType, seed , ctx)
435+ case udt : UserDefinedType [_] => computeHash(input, udt.sqlType, result , ctx)
453436 }
454437 }
455438}
0 commit comments