diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 8994eeff92c7f..104ad98ca099f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -233,7 +233,20 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true + override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) { + val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() + child match { + case CreateArray(ar) if intOrdinal < ar.length => + ar(intOrdinal).nullable + case GetArrayStructFields(CreateArray(elements), field, _, _, _) + if intOrdinal < elements.length => + elements(intOrdinal).nullable || field.nullable + case _ => + true + } + } else { + true + } override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index dc60464815043..d8d65715281d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -59,6 +59,39 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } + test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") { + // CreateArray case + val a = AttributeReference("a", IntegerType, nullable = false)() + val b = AttributeReference("b", IntegerType, nullable = true)() + val array = CreateArray(a :: b :: Nil) + assert(!GetArrayItem(array, Literal(0)).nullable) + assert(GetArrayItem(array, Literal(1)).nullable) + assert(!GetArrayItem(array, Subtract(Literal(2), Literal(2))).nullable) + assert(GetArrayItem(array, AttributeReference("ordinal", IntegerType)()).nullable) + + // GetArrayStructFields case + val f1 = StructField("a", IntegerType, nullable = false) + val f2 = StructField("b", IntegerType, nullable = true) + val structType = StructType(f1 :: f2 :: Nil) + val c = AttributeReference("c", structType, nullable = false)() + val inputArray1 = CreateArray(c :: Nil) + val inputArray1ContainsNull = c.nullable + val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull) + assert(!GetArrayItem(stArray1, Literal(0)).nullable) + val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull) + assert(GetArrayItem(stArray2, Literal(0)).nullable) + + val d = AttributeReference("d", structType, nullable = true)() + val inputArray2 = CreateArray(c :: d :: Nil) + val inputArray2ContainsNull = c.nullable || d.nullable + val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull) + assert(!GetArrayItem(stArray3, Literal(0)).nullable) + assert(GetArrayItem(stArray3, Literal(1)).nullable) + val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull) + assert(GetArrayItem(stArray4, Literal(0)).nullable) + assert(GetArrayItem(stArray4, Literal(1)).nullable) + } + test("GetMapValue") { val typeM = MapType(StringType, StringType) val map = Literal.create(Map("a" -> "b"), typeM)