Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,20 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def right: Expression = ordinal

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true
override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case _ =>
true
}
} else {
true
}

override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
}

test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!GetArrayItem(array, Literal(0)).nullable)
assert(GetArrayItem(array, Literal(1)).nullable)
assert(!GetArrayItem(array, Subtract(Literal(2), Literal(2))).nullable)
assert(GetArrayItem(array, AttributeReference("ordinal", IntegerType)()).nullable)

// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val stArray1 = GetArrayStructFields(CreateArray(c :: Nil), f1, 0, 2, containsNull = f1.nullable)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

containsNull should be c.nullable, according to the logic in ExtractValue.apply.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I'll update soon.

assert(!GetArrayItem(stArray1, Literal(0)).nullable)
val stArray2 = GetArrayStructFields(CreateArray(c :: Nil), f2, 1, 2, containsNull = f2.nullable)
assert(GetArrayItem(stArray2, Literal(0)).nullable)

val d = AttributeReference("d", structType, nullable = true)()
val stArray3 = GetArrayStructFields(CreateArray(c :: d :: Nil), f1, 0, 2,
containsNull = f1.nullable)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

assert(!GetArrayItem(stArray3, Literal(0)).nullable)
assert(GetArrayItem(stArray3, Literal(1)).nullable)
val stArray4 = GetArrayStructFields(CreateArray(c :: d :: Nil), f2, 1, 2,
containsNull = f2.nullable)
assert(GetArrayItem(stArray4, Literal(0)).nullable)
assert(GetArrayItem(stArray4, Literal(1)).nullable)
}

test("GetMapValue") {
val typeM = MapType(StringType, StringType)
val map = Literal.create(Map("a" -> "b"), typeM)
Expand Down