Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,27 @@ object LiteralGenerator {
lazy val longLiteralGen: Gen[Literal] =
for { l <- Arbitrary.arbLong.arbitrary } yield Literal.create(l, LongType)

// The floatLiteralGen and doubleLiteralGen will 50% of the time yield arbitrary values
// and 50% of the time will yield some special values that are more likely to reveal
// corner cases. This behavior is similar to the integral value generators.
lazy val floatLiteralGen: Gen[Literal] =
for {
f <- Gen.chooseNum(Float.MinValue / 2, Float.MaxValue / 2,
Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity)
f <- Gen.oneOf(
Gen.oneOf(
Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity, Float.MinPositiveValue,
Float.MaxValue, -Float.MaxValue, 0.0f, -0.0f, 1.0f, -1.0f),
Arbitrary.arbFloat.arbitrary
)
} yield Literal.create(f, FloatType)

lazy val doubleLiteralGen: Gen[Literal] =
for {
f <- Gen.chooseNum(Double.MinValue / 2, Double.MaxValue / 2,
Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity)
f <- Gen.oneOf(
Gen.oneOf(
Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity, Double.MinPositiveValue,
Double.MaxValue, -Double.MaxValue, 0.0, -0.0, 1.0, -1.0),
Arbitrary.arbDouble.arbitrary
)
} yield Literal.create(f, DoubleType)

// TODO cache the generated data
Expand Down Expand Up @@ -167,6 +178,8 @@ object LiteralGenerator {
case BinaryType => binaryLiteralGen
case CalendarIntervalType => calendarIntervalLiterGen
case DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale)
case ArrayType(et, _) => randomGen(et).map(
lit => Literal.create(Array(lit.value), ArrayType(et)))
case dt => throw new IllegalArgumentException(s"not supported type $dt")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.immutable.HashSet

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
Expand Down Expand Up @@ -91,6 +91,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
DataTypeTestUtils.propertyCheckSupported.foreach { dt =>
checkConsistencyBetweenInterpretedAndCodegen(EqualTo, dt, dt)
checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, dt, dt)

val arrayType = ArrayType(dt)
checkConsistencyBetweenInterpretedAndCodegen(EqualTo, arrayType, arrayType)
checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, arrayType, arrayType)
}
}

Expand Down Expand Up @@ -496,6 +500,32 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(EqualTo(infinity, infinity), true)
}

private def testEquality(literals: Seq[Literal]): Unit = {
literals.foreach(left => {
literals.foreach(right => {
checkEvaluation(EqualTo(left, right), true)
checkEvaluation(EqualNullSafe(left, right), true)

val leftArray = Literal.create(Array(left.value), ArrayType(left.dataType))
val rightArray = Literal.create(Array(right.value), ArrayType(right.dataType))
checkEvaluation(EqualTo(leftArray, rightArray), true)
checkEvaluation(EqualNullSafe(leftArray, rightArray), true)

val leftStruct = Literal.create(
Row(left.value), new StructType().add("a", left.dataType))
val rightStruct = Literal.create(
Row(right.value), new StructType().add("a", right.dataType))
checkEvaluation(EqualTo(leftStruct, rightStruct), true)
checkEvaluation(EqualNullSafe(leftStruct, rightStruct), true)
})
})
}

test("SPARK-32688: 0.0 and -0.0 should be equal") {
testEquality(Seq(Literal(0.0), Literal(-0.0)))
testEquality(Seq(Literal(0.0f), Literal(-0.0f)))
}

test("SPARK-22693: InSet should not use global variables") {
val ctx = new CodegenContext
InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
Expand Down