diff --git a/.gitignore b/.gitignore
index 0e142d1..2e9d929 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,4 @@ project/project/
project/target/
*.DS_Store
/target/
-/project/
+/project/build.properties
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a43a9f3..750e4bd 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,2 +1,3 @@
We happily welcome contributions to *Databricks Labs - dataframe-rules-engine*.
-We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes.
\ No newline at end of file
+We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes.
+Please fork this repository and submit a pull request with your changes.
\ No newline at end of file
diff --git a/README.md b/README.md
index 788d016..e5e6e04 100644
--- a/README.md
+++ b/README.md
@@ -182,3 +182,21 @@ cd Downloads
git pull repo
sbt clean package
```
+
+## Running tests
+To run the project's tests:
+```
+sbt test
+```
+
+Make sure JAVA_HOME is set up so that sbt can run the tests properly. You will need JDK 8, as Spark does
+not support newer versions of the JDK.
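+
+For example, on Linux you might point sbt at a JDK 8 install like this (the path is illustrative;
+substitute the location of your local JDK 8):
+```
+export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
+sbt test
+```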
+
+## Test coverage reports
+To generate a test coverage report for the project:
+```
+sbt jacoco
+```
+
+The test reports can be found in target/scala-<version>/jacoco/
+
diff --git a/build.sbt b/build.sbt
index abdd627..9115f60 100644
--- a/build.sbt
+++ b/build.sbt
@@ -9,6 +9,26 @@ scalacOptions ++= Seq("-Xmax-classfile-name", "78")
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.0"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0"
+libraryDependencies += "org.scalactic" %% "scalactic" % "3.1.1"
+libraryDependencies += "org.scalatest" %% "scalatest" % "3.1.1" % "test"
+
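+// Code coverage via the sbt-jacoco plugin (added in project/plugins.sbt); run with `sbt jacoco`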
+lazy val excludes = jacocoExcludes in Test := Seq()
+
+lazy val jacoco = jacocoReportSettings in Test := JacocoReportSettings(
+  "dataframe-rules-engine Coverage Report",
+  None,
+  JacocoThresholds(branch = 100),
+  Seq(JacocoReportFormats.ScalaHTML,
+    JacocoReportFormats.CSV),
+  "utf-8")
+
+val jacocoSettings = Seq(excludes, jacoco)
+lazy val jse = (project in file(".")).settings(jacocoSettings: _*)
+
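+// Fork tests into a separate JVM so the memory options below take effect; "-oD" makes ScalaTest print test durations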
+fork in Test := true
+javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:+CMSClassUnloadingEnabled")
+testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, "-oD")
+
lazy val commonSettings = Seq(
version := "0.1.1",
diff --git a/project/plugins.sbt b/project/plugins.sbt
new file mode 100644
index 0000000..8e606bc
--- /dev/null
+++ b/project/plugins.sbt
@@ -0,0 +1 @@
+addSbtPlugin("com.github.sbt" % "sbt-jacoco" % "3.0.3")
\ No newline at end of file
diff --git a/src/main/scala/com/databricks/labs/validation/QuickTest.scala b/src/main/scala/com/databricks/labs/validation/QuickTest.scala
new file mode 100644
index 0000000..233306b
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/validation/QuickTest.scala
@@ -0,0 +1,37 @@
+package com.databricks.labs.validation
+
+import com.databricks.labs.validation.utils.Structures.{Bounds, MinMaxRuleDef}
+import com.databricks.labs.validation.utils.SparkSessionWrapper
+import org.apache.spark.sql.functions._
+
+object QuickTest extends App with SparkSessionWrapper {
+
+ import spark.implicits._
+
+ val testDF = Seq(
+ (1, 2, 3),
+ (4, 5, 6),
+ (7, 8, 9)
+ ).toDF("retail_price", "scan_price", "cost")
+
+  // Standalone example rule; it is never added to the RuleSet below, and testDF has no "sku" column
+  Rule("Reasonable_sku_counts", count(col("sku")), Bounds(lower = 20.0, upper = 200.0))
+
+ val minMaxPriceDefs = Array(
+ MinMaxRuleDef("MinMax_Retail_Price_Minus_Scan_Price", col("retail_price")-col("scan_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Scan_Price_Minus_Retail_Price", col("scan_price")-col("retail_price"), Bounds(0.0, 29.99))
+ )
+
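+  // Rules named *_pass are expected to validate cleanly against testDF; the *_fail rules should be flagged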
+ val someRuleSet = RuleSet(testDF)
+ .add(Rule("retail_pass", col("retail_price"), Bounds(lower = 1.0, upper = 7.0)))
+ .add(Rule("retail_agg_pass_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 7.1)))
+ .add(Rule("retail_agg_pass_low", min(col("retail_price")), Bounds(lower = 0.0, upper = 7.0)))
+ .add(Rule("retail_fail_low", col("retail_price"), Bounds(lower = 1.1, upper = 7.0)))
+ .add(Rule("retail_fail_high", col("retail_price"), Bounds(lower = 0.0, upper = 6.9)))
+ .add(Rule("retail_agg_fail_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 6.9)))
+ .add(Rule("retail_agg_fail_low", min(col("retail_price")), Bounds(lower = 1.1, upper = 7.0)))
+ .addMinMaxRules(minMaxPriceDefs: _*)
+ val (rulesReport, passed) = someRuleSet.validate()
+
+ testDF.show(20, false)
+ rulesReport.show(20, false)
+}
diff --git a/src/main/scala/com/databricks/labs/validation/Rule.scala b/src/main/scala/com/databricks/labs/validation/Rule.scala
index 4d0add3..46bf413 100644
--- a/src/main/scala/com/databricks/labs/validation/Rule.scala
+++ b/src/main/scala/com/databricks/labs/validation/Rule.scala
@@ -20,7 +20,7 @@ class Rule {
private var _validNumerics: Array[Double] = _
private var _validStrings: Array[String] = _
private var _dateTimeLogic: Column = _
- private var _ruleType: String = _
+ private var _ruleType: RuleType.Value = _
private var _isAgg: Boolean = _
private def setRuleName(value: String): this.type = {
@@ -69,7 +69,7 @@ class Rule {
this
}
- private def setRuleType(value: String): this.type = {
+ private def setRuleType(value: RuleType.Value): this.type = {
_ruleType = value
this
}
@@ -99,7 +99,7 @@ class Rule {
def dateTimeLogic: Column = _dateTimeLogic
- def ruleType: String = _ruleType
+ def ruleType: RuleType.Value = _ruleType
private[validation] def isAgg: Boolean = _isAgg
@@ -121,7 +121,7 @@ object Rule {
.setRuleName(ruleName)
.setColumn(column)
.setBoundaries(boundaries)
- .setRuleType("bounds")
+ .setRuleType(RuleType.ValidateBounds)
.setIsAgg
}
@@ -135,7 +135,7 @@ object Rule {
.setRuleName(ruleName)
.setColumn(column)
.setValidNumerics(validNumerics)
- .setRuleType("validNumerics")
+ .setRuleType(RuleType.ValidateNumerics)
.setIsAgg
}
@@ -149,7 +149,7 @@ object Rule {
.setRuleName(ruleName)
.setColumn(column)
.setValidNumerics(validNumerics.map(_.toString.toDouble))
- .setRuleType("validNumerics")
+ .setRuleType(RuleType.ValidateNumerics)
.setIsAgg
}
@@ -163,7 +163,7 @@ object Rule {
.setRuleName(ruleName)
.setColumn(column)
.setValidNumerics(validNumerics.map(_.toString.toDouble))
- .setRuleType("validNumerics")
+ .setRuleType(RuleType.ValidateNumerics)
.setIsAgg
}
@@ -177,7 +177,7 @@ object Rule {
.setRuleName(ruleName)
.setColumn(column)
.setValidStrings(validStrings)
- .setRuleType("validStrings")
+ .setRuleType(RuleType.ValidateStrings)
.setIsAgg
}
diff --git a/src/main/scala/com/databricks/labs/validation/RuleType.scala b/src/main/scala/com/databricks/labs/validation/RuleType.scala
new file mode 100644
index 0000000..a4c5521
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/validation/RuleType.scala
@@ -0,0 +1,12 @@
+package com.databricks.labs.validation
+
+/**
+ * Definition of the Rule Types as an Enumeration for better type matching
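+ *
+ * A sketch of how the values are matched (as in Validator.scala):
+ * {{{
+ *   rule.ruleType match {
+ *     case RuleType.ValidateBounds => // boundary validation
+ *     case RuleType.ValidateNumerics | RuleType.ValidateStrings => // categorical validation
+ *     case _ => // ValidateDateTime / ValidateComplex are not yet implemented
+ *   }
+ * }}}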
+ */
+object RuleType extends Enumeration {
+ val ValidateBounds = Value("bounds")
+ val ValidateNumerics = Value("validNumerics")
+ val ValidateStrings = Value("validStrings")
+ val ValidateDateTime = Value("validDateTime")
+ val ValidateComplex = Value("complex")
+}
diff --git a/src/main/scala/com/databricks/labs/validation/Validator.scala b/src/main/scala/com/databricks/labs/validation/Validator.scala
index 983ba58..1f06009 100644
--- a/src/main/scala/com/databricks/labs/validation/Validator.scala
+++ b/src/main/scala/com/databricks/labs/validation/Validator.scala
@@ -13,11 +13,11 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
import spark.implicits._
- private val boundaryRules = ruleSet.getRules.filter(_.ruleType == "bounds")
- private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == "validNumerics" ||
- rule.ruleType == "validStrings")
- private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == "dateTime")
- private val complexRules = ruleSet.getRules.filter(_.ruleType == "complex")
+ private val boundaryRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateBounds)
+ private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == RuleType.ValidateNumerics ||
+ rule.ruleType == RuleType.ValidateStrings)
+ private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateDateTime)
+ private val complexRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateComplex)
private val byCols = ruleSet.getGroupBys map col
/**
@@ -36,15 +36,15 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
*/
private def buildValidationsByType(rule: Rule): Column = {
val nulls = mutable.Map[String, Column](
- "bounds" -> lit(null).cast(ArrayType(DoubleType)).alias("bounds"),
- "validNumerics" -> lit(null).cast(ArrayType(DoubleType)).alias("validNumerics"),
- "validStrings" -> lit(null).cast(ArrayType(StringType)).alias("validStrings"),
- "validDate" -> lit(null).cast(LongType).alias("validDate")
+ RuleType.ValidateBounds.toString -> lit(null).cast(ArrayType(DoubleType)).alias(RuleType.ValidateBounds.toString),
+ RuleType.ValidateNumerics.toString -> lit(null).cast(ArrayType(DoubleType)).alias(RuleType.ValidateNumerics.toString),
+ RuleType.ValidateStrings.toString -> lit(null).cast(ArrayType(StringType)).alias(RuleType.ValidateStrings.toString),
+ RuleType.ValidateDateTime.toString -> lit(null).cast(LongType).alias(RuleType.ValidateDateTime.toString)
)
rule.ruleType match {
- case "bounds" => nulls("bounds") = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias("bounds")
- case "validNumerics" => nulls("validNumerics") = lit(rule.validNumerics).alias("validNumerics")
- case "validStrings" => nulls("validStrings") = lit(rule.validStrings).alias("validStrings")
+ case RuleType.ValidateBounds => nulls(RuleType.ValidateBounds.toString) = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias(RuleType.ValidateBounds.toString)
+ case RuleType.ValidateNumerics => nulls(RuleType.ValidateNumerics.toString) = lit(rule.validNumerics).alias(RuleType.ValidateNumerics.toString)
+ case RuleType.ValidateStrings => nulls(RuleType.ValidateStrings.toString) = lit(rule.validStrings).alias(RuleType.ValidateStrings.toString)
}
val validationsByType = nulls.toMap.values.toSeq
struct(
@@ -61,7 +61,7 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
private def buildOutputStruct(rule: Rule, results: Seq[Column]): Column = {
struct(
lit(rule.ruleName).alias("Rule_Name"),
- lit(rule.ruleType).alias("Rule_Type"),
+ lit(rule.ruleType.toString).alias("Rule_Type"),
buildValidationsByType(rule),
struct(results: _*).alias("Results")
).alias("Validation")
@@ -101,8 +101,27 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
// Results must have Invalid_Count & Failed
rule.ruleType match {
- case "bounds" =>
+ case RuleType.ValidateBounds =>
+ // Rule evaluation for NON-AGG RULES ONLY
val invalid = rule.inputColumn < rule.boundaries.lower || rule.inputColumn > rule.boundaries.upper
+        // This is the first select; it must come before subsequent selects because it aliases the original
+        // column name to the rule name. ADDITIONALLY, this select evaluates the boundary rule WHEN the input
+        // column is NOT an agg. This can be confusing: for non-agg columns it renames the column to the
+        // rule name AND returns the count of invalid rows (not the original values).
+        // IF the rule IS an agg, the column is simply aliased to the rule name and no evaluation takes
+        // place here.
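+        // e.g. (illustrative, mirroring QuickTest) a non-agg rule on retail_price with Bounds(1.1, 7.0)
+        // over values (1, 4, 7): the value 1 falls below the lower bound, so this select yields sum = 1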
+ val first = if (!rule.isAgg) { // Not Agg
+ sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
+ } else { // Is Agg
+ rule.inputColumn.alias(rule.ruleName)
+ }
+        // WHEN RULE IS AGG -- this is where the evaluation happens. The input column was renamed to the
+        // rule name in the required previous select.
+        // IMPORTANT: agg expressions evaluate to a single output value, so when an agg is used the
+        // Invalid_Count cannot exceed 1.
+
+        // WHEN RULE IS NOT AGG -- determine whether the result of the "first" select (the count of invalid
+        // rows) is > 0; if it is, one or more rows violated the bounds and the rule has failed.
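+        // e.g. (illustrative, mirroring ValidatorTestSuite) an agg rule min(scan_price) with
+        // Bounds(3.0, 29.99) over values (2, 5, 8): min = 2 < 3.0, so Failed = true and Invalid_Count = 1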
val failed = if (rule.isAgg) {
when(
col(rule.ruleName) < rule.boundaries.lower || col(rule.ruleName) > rule.boundaries.upper, true)
@@ -110,19 +129,14 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
} else{
when(col(rule.ruleName) > 0,true).otherwise(false).alias("Failed")
}
- val first = if (!rule.isAgg) { // Not Agg
- sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
- } else { // Is Agg
- rule.inputColumn.alias(rule.ruleName)
- }
val results = if (rule.isAgg) {
Seq(when(failed, 1).otherwise(0).cast(LongType).alias("Invalid_Count"), failed)
} else {
Seq(col(rule.ruleName).cast(LongType).alias("Invalid_Count"), failed)
}
Selects(buildOutputStruct(rule, results), first)
- case x if x == "validNumerics" || x == "validStrings" =>
- val invalid = if (x == "validNumerics") {
+ case x if x == RuleType.ValidateNumerics || x == RuleType.ValidateStrings =>
+ val invalid = if (x == RuleType.ValidateNumerics) {
expr(s"size(array_except(${rule.ruleName}," +
s"array(${rule.validNumerics.mkString("D,")}D)))")
} else {
@@ -134,8 +148,8 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
val first = collect_set(rule.inputColumn).alias(rule.ruleName)
val results = Seq(invalid.cast(LongType).alias("Invalid_Count"), failed)
Selects(buildOutputStruct(rule, results), first)
- case "validDate" => ??? // TODO
- case "complex" => ??? // TODO
+ case RuleType.ValidateDateTime => ??? // TODO
+ case RuleType.ValidateComplex => ??? // TODO
}
})
}
diff --git a/src/test/resources/log4j.properties b/src/test/resources/log4j.properties
new file mode 100644
index 0000000..a87661a
--- /dev/null
+++ b/src/test/resources/log4j.properties
@@ -0,0 +1,13 @@
+# Set everything to be logged to the console
+log4j.rootCategory=ERROR, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+
+# Settings to quiet third party logs that are too verbose
+log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=WARN
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=WARN
+log4j.logger.org.apache.spark.sql.SparkSession$Builder=ERROR
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/labs/validation/SparkSessionFixture.scala b/src/test/scala/com/databricks/labs/validation/SparkSessionFixture.scala
new file mode 100644
index 0000000..094151d
--- /dev/null
+++ b/src/test/scala/com/databricks/labs/validation/SparkSessionFixture.scala
@@ -0,0 +1,12 @@
+package com.databricks.labs.validation
+
+import org.apache.spark.sql.SparkSession
+
+trait SparkSessionFixture {
+ lazy val spark = SparkSession
+ .builder()
+ .master("local")
+ .appName("spark session")
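+    // A single shuffle partition keeps the small local test jobs fast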
+ .config("spark.sql.shuffle.partitions", "1")
+ .getOrCreate()
+}
diff --git a/src/test/scala/com/databricks/labs/validation/ValidatorTestSuite.scala b/src/test/scala/com/databricks/labs/validation/ValidatorTestSuite.scala
new file mode 100644
index 0000000..2ac3080
--- /dev/null
+++ b/src/test/scala/com/databricks/labs/validation/ValidatorTestSuite.scala
@@ -0,0 +1,278 @@
+package com.databricks.labs.validation
+
+import com.databricks.labs.validation.utils.Structures.{Bounds, MinMaxRuleDef}
+import org.apache.spark.sql.functions.{col, min}
+
+case class ValidationValue(validDateTime: java.lang.Long, validNumerics: Array[Double], bounds: Array[Double], validStrings: Array[String])
+
+class ValidatorTestSuite extends org.scalatest.funsuite.AnyFunSuite with SparkSessionFixture {
+
+ import spark.implicits._
+ spark.sparkContext.setLogLevel("ERROR")
+
+ test("The input dataframe should have no rule failures on MinMaxRule") {
+ val expectedDF = Seq(
+ ("MinMax_Cost_Generated_max","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+ ("MinMax_Cost_Generated_min","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+ ("MinMax_Cost_manual_max","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+ ("MinMax_Cost_manual_min","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+ ("MinMax_Cost_max","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+ ("MinMax_Cost_min","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+ ("MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ ("MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ ("MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ ("MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+ ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+    // 2 rules per MinMax definition, so 2 MinMax_Sku_Price + 2 MinMax_Scan_Price + 2 MinMax_Cost
+    // + 2 MinMax_Cost_Generated + 2 MinMax_Cost_manual = 10 rules
+ val testDF = Seq(
+ (1, 2, 3),
+ (4, 5, 6),
+ (7, 8, 9)
+ ).toDF("retail_price", "scan_price", "cost")
+ val minMaxPriceDefs = Array(
+ MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Cost", col("cost"), Bounds(0.0, 12.0))
+ )
+
+ // Generate the array of Rules from the minmax generator
+ val rulesArray = RuleSet.generateMinMaxRules(MinMaxRuleDef("MinMax_Cost_Generated", col("cost"), Bounds(0.0, 12.0)))
+
+ val someRuleSet = RuleSet(testDF)
+ someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+ someRuleSet.addMinMaxRules("MinMax_Cost_manual", col("cost"), Bounds(0.0,12.0))
+ someRuleSet.add(rulesArray)
+ val (rulesReport, passed) = someRuleSet.validate()
+ assert(rulesReport.except(expectedDF).count() == 0)
+ assert(passed)
+ assert(rulesReport.count() == 10)
+ }
+
+  test("The input rules should have 1 invalid count each for MinMax_Retail_Price_Minus_Scan_Price min and max when a calculated column fails.") {
+ val expectedDF = Seq(
+ ("MinMax_Retail_Price_Minus_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),1,true),
+ ("MinMax_Retail_Price_Minus_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),1,true),
+ ("MinMax_Scan_Price_Minus_Retail_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ ("MinMax_Scan_Price_Minus_Retail_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+ ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+
+ val testDF = Seq(
+ (1, 2, 3),
+ (4, 5, 6),
+ (7, 8, 9)
+ ).toDF("retail_price", "scan_price", "cost")
+ val minMaxPriceDefs = Array(
+ MinMaxRuleDef("MinMax_Retail_Price_Minus_Scan_Price", col("retail_price")-col("scan_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Scan_Price_Minus_Retail_Price", col("scan_price")-col("retail_price"), Bounds(0.0, 29.99))
+ )
+
+ // Generate the array of Rules from the minmax generator
+ val someRuleSet = RuleSet(testDF)
+ someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+ val (rulesReport, passed) = someRuleSet.validate()
+ assert(rulesReport.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+ assert(!passed)
+ assert(rulesReport.count() == 4)
+ }
+
+  test("The input rule should have 1 invalid count for a failing aggregate rule.") {
+ val expectedDF = Seq(
+ ("MinMax_Min_Retail_Price","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ ("MinMax_Min_Scan_Price","bounds",ValidationValue(null,null,Array(3.0, 29.99),null),1,true)
+ ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+ val testDF = Seq(
+ (1, 2, 3),
+ (4, 5, 6),
+ (7, 8, 9)
+ ).toDF("retail_price", "scan_price", "cost")
+    val aggRules = Seq(
+      Rule("MinMax_Min_Retail_Price", min("retail_price"), Bounds(0.0, 29.99)),
+      Rule("MinMax_Min_Scan_Price", min("scan_price"), Bounds(3.0, 29.99))
+    )
+
+    val someRuleSet = RuleSet(testDF)
+    someRuleSet.add(aggRules)
+ val (rulesReport, passed) = someRuleSet.validate()
+ assert(rulesReport.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+ assert(!passed)
+ assert(rulesReport.count() == 2)
+ }
+
+ test("The input dataframe should have exactly 1 rule failure on MinMaxRule") {
+ val expectedDF = Seq(
+ ("MinMax_Cost_max","bounds",ValidationValue(null,null,Array(0.0, 12.00),null),1,true),
+ ("MinMax_Cost_min","bounds",ValidationValue(null,null,Array(0.0, 12.00),null),0,false),
+ ("MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ ("MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ ("MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ ("MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+ ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+ val testDF = Seq(
+ (1, 2, 3),
+ (4, 5, 6),
+ (7, 8, 99)
+ ).toDF("retail_price", "scan_price", "cost")
+ val minMaxPriceDefs = Array(
+ MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Cost", col("cost"), Bounds(0.0, 12.0))
+ )
+ // Generate the array of Rules from the minmax generator
+
+ val someRuleSet = RuleSet(testDF)
+ someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+ val (rulesReport, passed) = someRuleSet.validate()
+ val failedResults = rulesReport.filter(rulesReport("Invalid_Count") > 0).collect()
+ assert(failedResults.length == 1)
+ assert(rulesReport.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+ assert(failedResults(0)(0) == "MinMax_Cost_max")
+ assert(!passed)
+ }
+
+  test("The DF in the RuleSet object is the same as the input test df") {
+ val testDF = Seq(
+ (1, 2, 3),
+ (4, 5, 6),
+ (7, 8, 99)
+ ).toDF("retail_price", "scan_price", "cost")
+ val minMaxPriceDefs = Array(
+ MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Cost", col("cost"), Bounds(0.0, 12.0))
+ )
+ // Generate the array of Rules from the minmax generator
+
+ val someRuleSet = RuleSet(testDF)
+ someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+ val rulesDf = someRuleSet.getDf
+ assert(testDF.except(rulesDf).count() == 0)
+ }
+
+  test("The group by columns produce the correct group by clauses in the validation") {
+ val expectedDF = Seq(
+ (3,"MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (6,"MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (3,"MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (6,"MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (3,"MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (6,"MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (3,"MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (6,"MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+ ).toDF("cost","Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+ // 2 groups so count of the rules should yield (2 minmax rules * 2 columns) * 2 groups in cost (8 rows)
+ val testDF = Seq(
+ (1, 2, 3),
+ (4, 5, 6),
+ (7, 8, 3)
+ ).toDF("retail_price", "scan_price", "cost")
+ val minMaxPriceDefs = Array(
+ MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 29.99)),
+ MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99))
+ )
+
+ val someRuleSet = RuleSet(testDF, "cost")
+ someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+ val groupBys = someRuleSet.getGroupBys
+ val (groupByValidated, passed) = someRuleSet.validate()
+
+ assert(groupBys.length == 1)
+ assert(groupBys.head == "cost")
+ assert(someRuleSet.isGrouped)
+ assert(passed)
+ assert(groupByValidated.count() == 8)
+ assert(groupByValidated.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+ assert(groupByValidated.filter(groupByValidated("Invalid_Count") > 0).count() == 0)
+ assert(groupByValidated.filter(groupByValidated("Failed") === true).count() == 0)
+ }
+
+  test("Group by columns with rules that fail the validation") {
+ val expectedDF = Seq(
+ (3,"MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 0.0),null),1,true),
+ (6,"MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 0.0),null),1,true),
+ (3,"MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 0.0),null),1,true),
+ (6,"MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 0.0),null),1,true),
+ (3,"MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (6,"MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (3,"MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+ (6,"MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+ ).toDF("cost","Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+ // 2 groups so count of the rules should yield (2 minmax rules * 2 columns) * 2 groups in cost (8 rows)
+ val testDF = Seq(
+ (1, 2, 3),
+ (4, 5, 6),
+ (7, 8, 3)
+ ).toDF("retail_price", "scan_price", "cost")
+ val minMaxPriceDefs = Array(
+ MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 0.0)),
+ MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99))
+ )
+
+ val someRuleSet = RuleSet(testDF, "cost")
+ someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+ val groupBys = someRuleSet.getGroupBys
+ val (groupByValidated, passed) = someRuleSet.validate()
+
+ assert(groupBys.length == 1, "Group by length is not 1")
+ assert(groupBys.head == "cost", "Group by column is not cost")
+ assert(someRuleSet.isGrouped)
+ assert(!passed, "Rule set did not fail.")
+ assert(groupByValidated.count() == 8, "Rule count should be 8")
+ assert(groupByValidated.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+ assert(groupByValidated.filter(groupByValidated("Invalid_Count") > 0).count() == 4, "Invalid count is not 4.")
+ assert(groupByValidated.filter(groupByValidated("Failed") === true).count() == 4, "Failed count is not 4.")
+ }
+
+ test("Validate list of values with numeric types, string types and long types.") {
+
+ val testDF = Seq(
+ ("food_a", 2.51, 3, 111111111111111L),
+ ("food_b", 5.11, 6, 211111111111111L),
+ ("food_c", 8.22, 99, 311111111111111L)
+ ).toDF("product_name", "scan_price", "cost", "id")
+
+ val numericLovExpectedDF = Seq(
+ ("CheckIfCostIsInLOV","validNumerics",ValidationValue(null,Array(3,6,99),null,null),0,false),
+ ("CheckIfScanPriceIsInLOV","validNumerics",ValidationValue(null,Array(2.51,5.11,8.22),null,null),0,false),
+ ("CheckIfIdIsInLOV","validNumerics",ValidationValue(null,Array(111111111111111L,211111111111111L,311111111111111L),null,null),0,false)
+ ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+ val numericRules = Array(
+ Rule("CheckIfCostIsInLOV", col("cost"), Array(3,6,99)),
+ Rule("CheckIfScanPriceIsInLOV", col("scan_price"), Array(2.51,5.11,8.22)),
+ Rule("CheckIfIdIsInLOV", col("id"), Array(111111111111111L,211111111111111L,311111111111111L))
+ )
+
+ val numericRuleSet = RuleSet(testDF)
+ numericRuleSet.add(numericRules)
+ val (numericValidated, numericPassed) = numericRuleSet.validate()
+    assert(numericRules.forall(_.ruleType == RuleType.ValidateNumerics), "Not every rule is of type ValidateNumerics.")
+    assert(numericRules.forall(_.boundaries == null), "Boundaries are not null.")
+ assert(numericPassed)
+ assert(numericValidated.except(numericLovExpectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+ assert(numericValidated.filter(numericValidated("Invalid_Count") > 0).count() == 0)
+ assert(numericValidated.filter(numericValidated("Failed") === true).count() == 0)
+
+ val stringRule = Rule("CheckIfProductNameInLOV", col("product_name"), Array("food_a","food_b","food_c"))
+
+ val stringLovExpectedDF = Seq(
+ ("CheckIfProductNameInLOV","validStrings",ValidationValue(null,null,null,Array("food_a", "food_b", "food_c")),0,false)
+ ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+
+ val stringRuleSet = RuleSet(testDF)
+ stringRuleSet.add(stringRule)
+ val (stringValidated, stringPassed) = stringRuleSet.validate()
+ assert(stringRule.ruleType == RuleType.ValidateStrings)
+ assert(stringRule.boundaries == null)
+ assert(stringPassed)
+ assert(stringValidated.except(stringLovExpectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+ assert(stringValidated.filter(stringValidated("Invalid_Count") > 0).count() == 0)
+ assert(stringValidated.filter(stringValidated("Failed") === true).count() == 0)
+ }
+}