Skip to content

Commit a9f685b

Browse files
maropu authored and cloud-fan committed
[SPARK-25734][SQL] Literal should have a value corresponding to dataType
## What changes were proposed in this pull request? `Literal.value` should have a value corresponding to `dataType`. This pr added code to verify it and fixed the existing tests to do so. ## How was this patch tested? Modified the existing tests. Closes #22724 from maropu/SPARK-25734. Authored-by: Takeshi Yamamuro <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent e9af946 commit a9f685b

10 files changed

Lines changed: 92 additions & 39 deletions

File tree

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ private[kafka010] object KafkaWriter extends Logging {
5252
s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
5353
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
5454
} else {
55-
Literal(topic.get, StringType)
55+
Literal.create(topic.get, StringType)
5656
}
5757
).dataType match {
5858
case StringType => // good

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ import org.json4s.JsonAST._
4040
import org.apache.spark.sql.AnalysisException
4141
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
4242
import org.apache.spark.sql.catalyst.expressions.codegen._
43-
import org.apache.spark.sql.catalyst.util.DateTimeUtils
43+
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
4444
import org.apache.spark.sql.types._
4545
import org.apache.spark.unsafe.types._
46+
import org.apache.spark.util.Utils
4647

4748
object Literal {
4849
val TrueLiteral: Literal = Literal(true, BooleanType)
@@ -196,6 +197,47 @@ object Literal {
196197
case other =>
197198
throw new RuntimeException(s"no default for type $dataType")
198199
}
200+
201+
private[expressions] def validateLiteralValue(value: Any, dataType: DataType): Unit = {
202+
def doValidate(v: Any, dataType: DataType): Boolean = dataType match {
203+
case _ if v == null => true
204+
case BooleanType => v.isInstanceOf[Boolean]
205+
case ByteType => v.isInstanceOf[Byte]
206+
case ShortType => v.isInstanceOf[Short]
207+
case IntegerType | DateType => v.isInstanceOf[Int]
208+
case LongType | TimestampType => v.isInstanceOf[Long]
209+
case FloatType => v.isInstanceOf[Float]
210+
case DoubleType => v.isInstanceOf[Double]
211+
case _: DecimalType => v.isInstanceOf[Decimal]
212+
case CalendarIntervalType => v.isInstanceOf[CalendarInterval]
213+
case BinaryType => v.isInstanceOf[Array[Byte]]
214+
case StringType => v.isInstanceOf[UTF8String]
215+
case st: StructType =>
216+
v.isInstanceOf[InternalRow] && {
217+
val row = v.asInstanceOf[InternalRow]
218+
st.fields.map(_.dataType).zipWithIndex.forall {
219+
case (dt, i) => doValidate(row.get(i, dt), dt)
220+
}
221+
}
222+
case at: ArrayType =>
223+
v.isInstanceOf[ArrayData] && {
224+
val ar = v.asInstanceOf[ArrayData]
225+
ar.numElements() == 0 || doValidate(ar.get(0, at.elementType), at.elementType)
226+
}
227+
case mt: MapType =>
228+
v.isInstanceOf[MapData] && {
229+
val map = v.asInstanceOf[MapData]
230+
doValidate(map.keyArray(), ArrayType(mt.keyType)) &&
231+
doValidate(map.valueArray(), ArrayType(mt.valueType))
232+
}
233+
case ObjectType(cls) => cls.isInstance(v)
234+
case udt: UserDefinedType[_] => doValidate(v, udt.sqlType)
235+
case _ => false
236+
}
237+
require(doValidate(value, dataType),
238+
s"Literal must have a corresponding value to ${dataType.catalogString}, " +
239+
s"but class ${Utils.getSimpleName(value.getClass)} found.")
240+
}
199241
}
200242

201243
/**
@@ -240,6 +282,8 @@ object DecimalLiteral {
240282
*/
241283
case class Literal (value: Any, dataType: DataType) extends LeafExpression {
242284

285+
Literal.validateLiteralValue(value, dataType)
286+
243287
override def foldable: Boolean = true
244288
override def nullable: Boolean = value == null
245289

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ class TypeCoercionSuite extends AnalysisTest {
742742
val nullLit = Literal.create(null, NullType)
743743
val floatNullLit = Literal.create(null, FloatType)
744744
val floatLit = Literal.create(1.0f, FloatType)
745-
val timestampLit = Literal.create("2017-04-12", TimestampType)
745+
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
746746
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
747747
val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis())))
748748
val strArrayLit = Literal(Array("c"))
@@ -793,11 +793,11 @@ class TypeCoercionSuite extends AnalysisTest {
793793
ruleTest(TypeCoercion.FunctionArgumentConversion,
794794
CreateArray(Literal(1.0)
795795
:: Literal(1)
796-
:: Literal.create(1.0, FloatType)
796+
:: Literal.create(1.0f, FloatType)
797797
:: Nil),
798798
CreateArray(Literal(1.0)
799799
:: Cast(Literal(1), DoubleType)
800-
:: Cast(Literal.create(1.0, FloatType), DoubleType)
800+
:: Cast(Literal.create(1.0f, FloatType), DoubleType)
801801
:: Nil))
802802

803803
ruleTest(TypeCoercion.FunctionArgumentConversion,
@@ -834,23 +834,23 @@ class TypeCoercionSuite extends AnalysisTest {
834834
ruleTest(TypeCoercion.FunctionArgumentConversion,
835835
CreateMap(Literal(1)
836836
:: Literal("a")
837-
:: Literal.create(2.0, FloatType)
837+
:: Literal.create(2.0f, FloatType)
838838
:: Literal("b")
839839
:: Nil),
840840
CreateMap(Cast(Literal(1), FloatType)
841841
:: Literal("a")
842-
:: Literal.create(2.0, FloatType)
842+
:: Literal.create(2.0f, FloatType)
843843
:: Literal("b")
844844
:: Nil))
845845
ruleTest(TypeCoercion.FunctionArgumentConversion,
846846
CreateMap(Literal.create(null, DecimalType(5, 3))
847847
:: Literal("a")
848-
:: Literal.create(2.0, FloatType)
848+
:: Literal.create(2.0f, FloatType)
849849
:: Literal("b")
850850
:: Nil),
851851
CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
852852
:: Literal("a")
853-
:: Literal.create(2.0, FloatType).cast(DoubleType)
853+
:: Literal.create(2.0f, FloatType).cast(DoubleType)
854854
:: Literal("b")
855855
:: Nil))
856856
// type coercion for map values
@@ -895,11 +895,11 @@ class TypeCoercionSuite extends AnalysisTest {
895895
ruleTest(TypeCoercion.FunctionArgumentConversion,
896896
operator(Literal(1.0)
897897
:: Literal(1)
898-
:: Literal.create(1.0, FloatType)
898+
:: Literal.create(1.0f, FloatType)
899899
:: Nil),
900900
operator(Literal(1.0)
901901
:: Cast(Literal(1), DoubleType)
902-
:: Cast(Literal.create(1.0, FloatType), DoubleType)
902+
:: Cast(Literal.create(1.0f, FloatType), DoubleType)
903903
:: Nil))
904904
ruleTest(TypeCoercion.FunctionArgumentConversion,
905905
operator(Literal(1L)
@@ -966,7 +966,7 @@ class TypeCoercionSuite extends AnalysisTest {
966966
val falseLit = Literal.create(false, BooleanType)
967967
val stringLit = Literal.create("c", StringType)
968968
val floatLit = Literal.create(1.0f, FloatType)
969-
val timestampLit = Literal.create("2017-04-12", TimestampType)
969+
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
970970
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
971971

972972
ruleTest(rule,
@@ -1016,14 +1016,16 @@ class TypeCoercionSuite extends AnalysisTest {
10161016
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
10171017
)
10181018
ruleTest(TypeCoercion.CaseWhenCoercion,
1019-
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
10201019
CaseWhen(Seq((Literal(true), Literal(1.2))),
1021-
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
1020+
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
1021+
CaseWhen(Seq((Literal(true), Literal(1.2))),
1022+
Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DoubleType))
10221023
)
10231024
ruleTest(TypeCoercion.CaseWhenCoercion,
1024-
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
1025+
CaseWhen(Seq((Literal(true), Literal(100L))),
1026+
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
10251027
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
1026-
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
1028+
Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DecimalType(22, 2)))
10271029
)
10281030
}
10291031

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
623623

624624
test("SPARK-21513: to_json support map[string, struct] to json") {
625625
val schema = MapType(StringType, StructType(StructField("a", IntegerType) :: Nil))
626-
val input = Literal.create(ArrayBasedMapData(Map("test" -> InternalRow(1))), schema)
626+
val input = Literal(
627+
ArrayBasedMapData(Map(UTF8String.fromString("test") -> InternalRow(1))), schema)
627628
checkEvaluation(
628629
StructsToJson(Map.empty, input),
629630
"""{"test":{"a":1}}"""
@@ -633,7 +634,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
633634
test("SPARK-21513: to_json support map[struct, struct] to json") {
634635
val schema = MapType(StructType(StructField("a", IntegerType) :: Nil),
635636
StructType(StructField("b", IntegerType) :: Nil))
636-
val input = Literal.create(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
637+
val input = Literal(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
637638
checkEvaluation(
638639
StructsToJson(Map.empty, input),
639640
"""{"[1]":{"b":2}}"""
@@ -642,7 +643,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
642643

643644
test("SPARK-21513: to_json support map[string, integer] to json") {
644645
val schema = MapType(StringType, IntegerType)
645-
val input = Literal.create(ArrayBasedMapData(Map("a" -> 1)), schema)
646+
val input = Literal(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)), schema)
646647
checkEvaluation(
647648
StructsToJson(Map.empty, input),
648649
"""{"a":1}"""
@@ -651,17 +652,18 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
651652

652653
test("to_json - array with maps") {
653654
val inputSchema = ArrayType(MapType(StringType, IntegerType))
654-
val input = new GenericArrayData(ArrayBasedMapData(
655-
Map("a" -> 1)) :: ArrayBasedMapData(Map("b" -> 2)) :: Nil)
655+
val input = new GenericArrayData(
656+
ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) ::
657+
ArrayBasedMapData(Map(UTF8String.fromString("b") -> 2)) :: Nil)
656658
val output = """[{"a":1},{"b":2}]"""
657659
checkEvaluation(
658-
StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
660+
StructsToJson(Map.empty, Literal(input, inputSchema), gmtId),
659661
output)
660662
}
661663

662664
test("to_json - array with single map") {
663665
val inputSchema = ArrayType(MapType(StringType, IntegerType))
664-
val input = new GenericArrayData(ArrayBasedMapData(Map("a" -> 1)) :: Nil)
666+
val input = new GenericArrayData(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) :: Nil)
665667
val output = """[{"a":1}]"""
666668
checkEvaluation(
667669
StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import java.sql.Timestamp
21+
2022
import org.apache.spark.SparkFunSuite
2123
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
2224
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -107,8 +109,8 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
107109
val nullLit = Literal.create(null, NullType)
108110
val floatNullLit = Literal.create(null, FloatType)
109111
val floatLit = Literal.create(1.01f, FloatType)
110-
val timestampLit = Literal.create("2017-04-12", TimestampType)
111-
val decimalLit = Literal.create(10.2, DecimalType(20, 2))
112+
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
113+
val decimalLit = Literal.create(BigDecimal.valueOf(10.2), DecimalType(20, 2))
112114

113115
assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType)
114116
assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.sql.{Date, Timestamp}
20+
import java.sql.Timestamp
2121
import java.util.TimeZone
2222

2323
import org.apache.spark.SparkFunSuite
@@ -32,9 +32,9 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
3232
val b2 = Literal.create(true, BooleanType)
3333
val i1 = Literal.create(20132983, IntegerType)
3434
val i2 = Literal.create(-20132983, IntegerType)
35-
val l1 = Literal.create(20132983, LongType)
36-
val l2 = Literal.create(-20132983, LongType)
37-
val millis = 1524954911000L;
35+
val l1 = Literal.create(20132983L, LongType)
36+
val l2 = Literal.create(-20132983L, LongType)
37+
val millis = 1524954911000L
3838
// Explicitly choose a time zone, since Date objects can create different values depending on
3939
// local time zone of the machine on which the test is running
4040
val oldDefaultTZ = TimeZone.getDefault
@@ -57,7 +57,7 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
5757
val dec1 = Literal(Decimal(20132983L, 10, 2))
5858
val dec2 = Literal(Decimal(20132983L, 19, 2))
5959
val dec3 = Literal(Decimal(20132983L, 21, 2))
60-
val list1 = Literal(List(1, 2), ArrayType(IntegerType))
60+
val list1 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
6161
val nullVal = Literal.create(null, IntegerType)
6262

6363
checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
105105
}
106106

107107
test("parse sql expression for duration in microseconds - long") {
108-
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType)))
108+
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2L << 52, LongType)))
109109
assert(dur.isInstanceOf[Long])
110-
assert(dur === (2 << 52))
110+
assert(dur === (2L << 52))
111111
}
112112

113113
test("parse sql expression for duration in microseconds - invalid interval") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,14 @@ class PercentileSuite extends SparkFunSuite {
232232
BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType)
233233

234234
invalidDataTypes.foreach { dataType =>
235-
val percentage = Literal(0.5, dataType)
235+
val percentage = Literal.default(dataType)
236236
val percentile4 = new Percentile(child, percentage)
237-
assertEqual(percentile4.checkInputDataTypes(),
238-
TypeCheckFailure(s"argument 2 requires double type, however, " +
239-
s"'0.5' is of ${dataType.simpleString} type."))
237+
val checkResult = percentile4.checkInputDataTypes()
238+
assert(checkResult.isFailure)
239+
Seq("argument 2 requires double type, however, ",
240+
s"is of ${dataType.simpleString} type.").foreach { errMsg =>
241+
assert(checkResult.asInstanceOf[TypeCheckFailure].message.contains(errMsg))
242+
}
240243
}
241244
}
242245

sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,13 @@ case class AnalyzeColumnCommand(
210210
def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
211211
expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
212212
})
213-
val one = Literal(1, LongType)
213+
val one = Literal(1L, LongType)
214214

215215
// the approximate ndv (num distinct value) should never be larger than the number of rows
216216
val numNonNulls = if (col.nullable) Count(col) else Count(one)
217217
val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls))
218218
val numNulls = Subtract(Count(one), numNonNulls)
219-
val defaultSize = Literal(col.dataType.defaultSize, LongType)
219+
val defaultSize = Literal(col.dataType.defaultSize.toLong, LongType)
220220
val nullArray = Literal(null, ArrayType(LongType))
221221

222222
def fixedLenTypeStruct: CreateNamedStruct = {

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
228228
test("join key rewritten") {
229229
val l = Literal(1L)
230230
val i = Literal(2)
231-
val s = Literal.create(3, ShortType)
231+
val s = Literal.create(3.toShort, ShortType)
232232
val ss = Literal("hello")
233233

234234
assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)

0 commit comments

Comments
 (0)