diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index cf9796ef1948..17c683cc8ff5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1464,17 +1464,29 @@ case class ArrayContains(left: Expression, right: Expression) nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") val getValue = CodeGenerator.getValue(arr, right.dataType, i) - s""" - for (int $i = 0; $i < $arr.numElements(); $i ++) { - if ($arr.isNullAt($i)) { - ${ev.isNull} = true; - } else if (${ctx.genEqual(right.dataType, value, getValue)}) { - ${ev.isNull} = false; - ${ev.value} = true; - break; - } + val loopBodyCode = if (nullable) { + s""" + |if ($arr.isNullAt($i)) { + | ${ev.isNull} = true; + |} else if (${ctx.genEqual(right.dataType, value, getValue)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin + } else { + s""" + |if (${ctx.genEqual(right.dataType, value, getValue)}) { + | ${ev.value} = true; + | break; + |} + """.stripMargin } - """ + s""" + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $loopBodyCode + |} + """.stripMargin }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7b345aabd19c..a9fc3e9c7b37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -383,10 +383,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a3 = Literal.create(null, ArrayType(StringType)) val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq( StructField("a", IntegerType, true))))) + // Explicitly mark the array type not nullable (spark-25308) + val a5 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) + checkEvaluation(ArrayContains(a5, Literal(1)), true) checkEvaluation(ArrayContains(a1, Literal("")), true) checkEvaluation(ArrayContains(a1, Literal("a")), null)