diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index fd40741cfb5f..d4a19132405b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -226,23 +226,18 @@ object DataType { } } + private val NoNameCheck = 0 + private val CaseSensitiveNameCheck = 1 + private val CaseInsensitiveNameCheck = 2 + private val NoNullabilityCheck = 0 + private val NullabilityCheck = 1 + private val CompatibleNullabilityCheck = 2 + /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { - (left, right) match { - case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => - equalsIgnoreNullability(leftElementType, rightElementType) - case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => - equalsIgnoreNullability(leftKeyType, rightKeyType) && - equalsIgnoreNullability(leftValueType, rightValueType) - case (StructType(leftFields), StructType(rightFields)) => - leftFields.length == rightFields.length && - leftFields.zip(rightFields).forall { case (l, r) => - l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) - } - case (l, r) => l == r - } + equalsDataTypes(left, right, CaseSensitiveNameCheck, NoNullabilityCheck) } /** @@ -260,25 +255,7 @@ object DataType { * of `fromField.nullable` and `toField.nullable` are false. */ private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => - (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) - - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - (tn || !fn) && - equalsIgnoreCompatibleNullability(fromKey, toKey) && - equalsIgnoreCompatibleNullability(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (fromField, toField) => - fromField.name == toField.name && - (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) - } - - case (fromDataType, toDataType) => fromDataType == toDataType - } + equalsDataTypes(from, to, CaseSensitiveNameCheck, CompatibleNullabilityCheck) } /** @@ -286,23 +263,7 @@ object DataType { * sensitivity of field names in StructType. */ private[sql] def equalsIgnoreCaseAndNullability(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (ArrayType(fromElement, _), ArrayType(toElement, _)) => - equalsIgnoreCaseAndNullability(fromElement, toElement) - - case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => - equalsIgnoreCaseAndNullability(fromKey, toKey) && - equalsIgnoreCaseAndNullability(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (l, r) => - l.name.equalsIgnoreCase(r.name) && - equalsIgnoreCaseAndNullability(l.dataType, r.dataType) - } - - case (fromDataType, toDataType) => fromDataType == toDataType - } + equalsDataTypes(from, to, CaseInsensitiveNameCheck, NoNullabilityCheck) } /** @@ -315,25 +276,82 @@ object DataType { from: DataType, to: DataType, ignoreNullability: Boolean = false): Boolean = { - (from, to) match { + if (ignoreNullability) { + equalsDataTypes(from, to, NoNameCheck, NoNullabilityCheck) + } else { + equalsDataTypes(from, to, NoNameCheck, NullabilityCheck) + } + } + + /** Given the fieldNames compare for equality based on nameCheckType */ + private def isSameFieldName(left: String, right: String, nameCheckType: Int): Boolean = { + nameCheckType match { + case NoNameCheck => true + case CaseSensitiveNameCheck => left == right + case CaseInsensitiveNameCheck => left.toLowerCase == right.toLowerCase + } + } + + /** Given the nullability of two datatypes compare for equality based on nullabilityCheckType */ + private def isSameNullability( + leftNullability: Boolean, + rightNullability: Boolean, + nullabilityCheckType: Int): Boolean = { + nullabilityCheckType match { + case NoNullabilityCheck => true + case NullabilityCheck => leftNullability == rightNullability + case CompatibleNullabilityCheck => rightNullability || !leftNullability + } + } + + /** + * Compare two dataTypes based on - + * nameCheckType - (NoNameCheck, CaseSensitiveNameCheck, CaseInsensitiveNameCheck) + * nullabilityCheckType - (NoNullabilityCheck, NullabilityCheck, CompatibleNullabilityCheck) + * @param left + * @param right + * @param nameCheckType + * @param nullabilityCheckType + * @return + */ + private def equalsDataTypes( + left: DataType, + right: DataType, + nameCheckType: Int, + nullabilityCheckType: Int + ): Boolean = { + (left, right) match { case (left: ArrayType, right: ArrayType) => - equalsStructurally(left.elementType, right.elementType) && - (ignoreNullability || left.containsNull == right.containsNull) + val sameNullability = isSameNullability(left.containsNull, right.containsNull, + nullabilityCheckType) + val sameType = equalsDataTypes(left.elementType, right.elementType, + nameCheckType, nullabilityCheckType) + sameNullability && sameType case (left: MapType, right: MapType) => - equalsStructurally(left.keyType, right.keyType) && - equalsStructurally(left.valueType, right.valueType) && - (ignoreNullability || left.valueContainsNull == right.valueContainsNull) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields) - .forall { case (l, r) => - equalsStructurally(l.dataType, r.dataType) && - (ignoreNullability || l.nullable == r.nullable) - } - - case (fromDataType, toDataType) => fromDataType == toDataType + val sameNullability = isSameNullability(left.valueContainsNull, right.valueContainsNull, + nullabilityCheckType) + val sameKeyType = equalsDataTypes(left.keyType, right.keyType, + nameCheckType, nullabilityCheckType) + val sameValueType = equalsDataTypes(left.valueType, right.valueType, + nameCheckType, nullabilityCheckType) + sameNullability && sameKeyType && sameValueType + + case (StructType(leftFields), StructType(rightFields)) => + leftFields.length == rightFields.length && + leftFields.zip(rightFields).forall { case (lf, rf) => + val sameFieldName = isSameFieldName(lf.name, rf.name, nameCheckType) + val sameNullability = isSameNullability(lf.nullable, rf.nullable, nullabilityCheckType) + val sameType = equalsDataTypes(lf.dataType, rf.dataType, + nameCheckType, nullabilityCheckType) + + sameFieldName && sameNullability && sameType + } + + case (leftDataType, rightDataType) => leftDataType == rightDataType } } + + + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 5a86f4055dce..43b67d072c23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -310,6 +310,88 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 12) checkDefaultSize(structType, 20) + def checkEqualsIgnoreNullability( + from: DataType, + to: DataType, + expected: Boolean): Unit = { + val testName = + s"equalsIgnoreNullability: (from: $from, to: $to)" + test(testName) { + assert(DataType.equalsIgnoreNullability(from, to) === expected) + } + } + + checkEqualsIgnoreNullability( + from = ArrayType(DoubleType, containsNull = false), + to = ArrayType(DoubleType, containsNull = true), + expected = true) + checkEqualsIgnoreNullability( + from = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(DoubleType, containsNull = false), nullable = true):: Nil + ), + to = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(DoubleType, containsNull = false), nullable = false):: Nil + ), + expected = true) + checkEqualsIgnoreNullability( + from = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(DoubleType, containsNull = false), nullable = true):: Nil + ), + to = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("c", ArrayType(DoubleType, containsNull = false), nullable = false):: Nil + ), + expected = false) + checkEqualsIgnoreNullability( + from = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(DoubleType, containsNull = false), nullable = true):: Nil + ), + to = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("B", ArrayType(DoubleType, containsNull = false), nullable = true):: Nil + ), + expected = false) + checkEqualsIgnoreNullability( + from = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(MapType(StringType, StringType, valueContainsNull = false), + containsNull = false), nullable = true):: Nil + ), + to = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(MapType(StringType, StringType, valueContainsNull = false), + containsNull = false), nullable = true):: Nil + ), + expected = true) + + def checkEqualsIgnoreCaseAndNullability( + from: DataType, + to: DataType, + expected: Boolean): Unit = { + val testName = + s"equalsIgnoreCaseAndNullability: (from: $from, to: $to)" + test(testName) { + assert(DataType.equalsIgnoreCaseAndNullability(from, to) === expected) + } + } + + checkEqualsIgnoreCaseAndNullability( + from = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(MapType(IntegerType, StringType, valueContainsNull = false), + containsNull = false), nullable = true):: Nil + ), + to = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("B", ArrayType(MapType(IntegerType, StringType, valueContainsNull = false), + containsNull = false), nullable = true):: Nil + ), + expected = true) + def checkEqualsIgnoreCompatibleNullability( from: DataType, to: DataType, @@ -392,6 +474,30 @@ class DataTypeSuite extends SparkFunSuite { StructField("a", StringType, nullable = false) :: StructField("b", StringType, nullable = false) :: Nil), expected = false) + checkEqualsIgnoreCompatibleNullability( + from = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(MapType(IntegerType, StringType, valueContainsNull = false), + containsNull = false), nullable = true):: Nil + ), + to = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("B", ArrayType(MapType(IntegerType, StringType, valueContainsNull = true), + containsNull = false), nullable = false):: Nil + ), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(MapType(IntegerType, StringType, valueContainsNull = false), + containsNull = false), nullable = false):: Nil + ), + to = StructType( + StructField("a", DoubleType, nullable = false):: + StructField("b", ArrayType(MapType(IntegerType, StringType, valueContainsNull = true), + containsNull = false), nullable = true):: Nil + ), + expected = true) def checkCatalogString(dt: DataType): Unit = { test(s"catalogString: $dt") {