diff --git a/.gitignore b/.gitignore
index 0e142d1..2e9d929 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,4 @@ project/project/
 project/target/
 *.DS_Store
 /target/
-/project/
+/project/build.properties
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a43a9f3..750e4bd 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,2 +1,3 @@
 We happily welcome contributions to *Databricks Labs - dataframe-rules-engine*.
-We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes.
\ No newline at end of file
+We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes.
+Please make a fork of this repository and submit a pull request.
\ No newline at end of file
diff --git a/README.md b/README.md
index 788d016..e5e6e04 100644
--- a/README.md
+++ b/README.md
@@ -182,3 +182,21 @@ cd Downloads
 git pull repo
 sbt clean package
 ```
+
+## Running tests
+To run tests on the project:
+```
+sbt test
+```
+
+Make sure that your JAVA_HOME is set up so that sbt can run the tests properly. You will need JDK 8, as Spark does
+not support newer versions of the JDK.
+
+## Test coverage reports
+To get a test coverage report for the project:
+```
+sbt jacoco
+```
+
+The test reports can be found in target/scala-/jacoco/
+
diff --git a/build.sbt b/build.sbt
index abdd627..9115f60 100644
--- a/build.sbt
+++ b/build.sbt
@@ -9,6 +9,26 @@ scalacOptions ++= Seq("-Xmax-classfile-name", "78")
 libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.0"
 libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0"
+libraryDependencies += "org.scalactic" %% "scalactic" % "3.1.1"
+libraryDependencies += "org.scalatest" %% "scalatest" % "3.1.1" % "test"
+
+lazy val excludes = jacocoExcludes in Test := Seq()
+
+lazy val jacoco = jacocoReportSettings in Test := JacocoReportSettings(
+  "Jacoco Scala Example Coverage Report",
+  None,
+  JacocoThresholds(branch = 100),
+  Seq(JacocoReportFormats.ScalaHTML,
+    JacocoReportFormats.CSV),
+  "utf-8")
+
+val jacocoSettings = Seq(jacoco)
+lazy val jse = (project in file(".")).settings(jacocoSettings: _*)
+
+fork in Test := true
+javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:+CMSClassUnloadingEnabled")
+testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, "-oD")
+
 lazy val commonSettings = Seq(
   version := "0.1.1",
diff --git a/project/plugins.sbt b/project/plugins.sbt
new file mode 100644
index 0000000..8e606bc
--- /dev/null
+++ b/project/plugins.sbt
@@ -0,0 +1 @@
+addSbtPlugin("com.github.sbt" % "sbt-jacoco" % "3.0.3")
\ No newline at end of file
diff --git a/src/main/scala/com/databricks/labs/validation/QuickTest.scala b/src/main/scala/com/databricks/labs/validation/QuickTest.scala
new file mode 100644
index 0000000..233306b
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/validation/QuickTest.scala
@@ -0,0 +1,37 @@
+package com.databricks.labs.validation
+
+import com.databricks.labs.validation.utils.Structures.{Bounds, MinMaxRuleDef}
+import com.databricks.labs.validation.utils.SparkSessionWrapper
+import org.apache.spark.sql.functions._
+
+object QuickTest extends App with SparkSessionWrapper {
+
+  import spark.implicits._
+
+  val testDF = Seq(
+    (1, 2, 3),
+    (4, 5, 6),
+    (7, 8, 9)
+  ).toDF("retail_price", "scan_price", "cost")
+
+  Rule("Reasonable_sku_counts", count(col("sku")), Bounds(lower = 20.0, upper = 200.0)) // standalone Rule construction example; never added to the RuleSet below
+
+  val minMaxPriceDefs = Array(
+    MinMaxRuleDef("MinMax_Retail_Price_Minus_Scan_Price", col("retail_price") - col("scan_price"), Bounds(0.0, 29.99)),
+    MinMaxRuleDef("MinMax_Scan_Price_Minus_Retail_Price", col("scan_price") - col("retail_price"), Bounds(0.0, 29.99))
+  )
+
+  val someRuleSet = RuleSet(testDF)
+    .add(Rule("retail_pass", col("retail_price"), Bounds(lower = 1.0, upper = 7.0)))
+    .add(Rule("retail_agg_pass_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 7.1)))
+    .add(Rule("retail_agg_pass_low", min(col("retail_price")), Bounds(lower = 0.0, upper = 7.0)))
+    .add(Rule("retail_fail_low", col("retail_price"), Bounds(lower = 1.1, upper = 7.0)))
+    .add(Rule("retail_fail_high", col("retail_price"), Bounds(lower = 0.0, upper = 6.9)))
+    .add(Rule("retail_agg_fail_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 6.9)))
+    .add(Rule("retail_agg_fail_low", min(col("retail_price")), Bounds(lower = 1.1, upper = 7.0)))
+    .addMinMaxRules(minMaxPriceDefs: _*)
+  val (rulesReport, passed) = someRuleSet.validate()
+
+  testDF.show(20, false)
+  rulesReport.show(20, false)
+}
diff --git a/src/main/scala/com/databricks/labs/validation/Rule.scala b/src/main/scala/com/databricks/labs/validation/Rule.scala
index 4d0add3..46bf413 100644
--- a/src/main/scala/com/databricks/labs/validation/Rule.scala
+++ b/src/main/scala/com/databricks/labs/validation/Rule.scala
@@ -20,7 +20,7 @@ class Rule {
   private var _validNumerics: Array[Double] = _
   private var _validStrings: Array[String] = _
   private var _dateTimeLogic: Column = _
-  private var _ruleType: String = _
+  private var _ruleType: RuleType.Value = _
   private var _isAgg: Boolean = _
 
   private def setRuleName(value: String): this.type = {
@@ -69,7 +69,7 @@ class Rule {
     this
   }
 
-  private def setRuleType(value: String): this.type = {
+  private def setRuleType(value: RuleType.Value): this.type = {
     _ruleType = value
     this
   }
@@ -99,7 +99,7 @@ class Rule {
 
   def dateTimeLogic: Column = _dateTimeLogic
 
-  def ruleType: String = _ruleType
+  def ruleType: RuleType.Value = _ruleType
 
   private[validation] def isAgg: Boolean = _isAgg
 
@@ -121,7 +121,7 @@ object Rule {
       .setRuleName(ruleName)
       .setColumn(column)
      .setBoundaries(boundaries)
-      .setRuleType("bounds")
+      .setRuleType(RuleType.ValidateBounds)
       .setIsAgg
   }
 
@@ -135,7 +135,7 @@
       .setRuleName(ruleName)
       .setColumn(column)
       .setValidNumerics(validNumerics)
-      .setRuleType("validNumerics")
+      .setRuleType(RuleType.ValidateNumerics)
       .setIsAgg
   }
 
@@ -149,7 +149,7 @@
       .setRuleName(ruleName)
       .setColumn(column)
       .setValidNumerics(validNumerics.map(_.toString.toDouble))
-      .setRuleType("validNumerics")
+      .setRuleType(RuleType.ValidateNumerics)
       .setIsAgg
   }
 
@@ -163,7 +163,7 @@
       .setRuleName(ruleName)
       .setColumn(column)
       .setValidNumerics(validNumerics.map(_.toString.toDouble))
-      .setRuleType("validNumerics")
+      .setRuleType(RuleType.ValidateNumerics)
       .setIsAgg
   }
 
@@ -177,7 +177,7 @@
       .setRuleName(ruleName)
       .setColumn(column)
       .setValidStrings(validStrings)
-      .setRuleType("validStrings")
+      .setRuleType(RuleType.ValidateStrings)
       .setIsAgg
   }
 
diff --git a/src/main/scala/com/databricks/labs/validation/RuleType.scala b/src/main/scala/com/databricks/labs/validation/RuleType.scala
new file mode 100644
index 0000000..a4c5521
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/validation/RuleType.scala
@@ -0,0 +1,12 @@
+package com.databricks.labs.validation
+
+/**
+ * Definition of the Rule Types as an Enumeration for better type matching
+ */
+object RuleType extends Enumeration {
+  val ValidateBounds = Value("bounds")
+  val ValidateNumerics = Value("validNumerics")
+  val ValidateStrings = Value("validStrings")
+  val ValidateDateTime = Value("validDateTime")
+  val ValidateComplex = Value("complex")
+}
diff --git a/src/main/scala/com/databricks/labs/validation/Validator.scala b/src/main/scala/com/databricks/labs/validation/Validator.scala
index 983ba58..1f06009 100644
--- a/src/main/scala/com/databricks/labs/validation/Validator.scala
+++ b/src/main/scala/com/databricks/labs/validation/Validator.scala
@@ -13,11 +13,11 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
 
   import spark.implicits._
 
-  private val boundaryRules = ruleSet.getRules.filter(_.ruleType == "bounds")
-  private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == "validNumerics" ||
-    rule.ruleType == "validStrings")
-  private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == "dateTime")
-  private val complexRules = ruleSet.getRules.filter(_.ruleType == "complex")
+  private val boundaryRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateBounds)
+  private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == RuleType.ValidateNumerics ||
+    rule.ruleType == RuleType.ValidateStrings)
+  private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateDateTime)
+  private val complexRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateComplex)
   private val byCols = ruleSet.getGroupBys map col
 
   /**
@@ -36,15 +36,15 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
    */
   private def buildValidationsByType(rule: Rule): Column = {
     val nulls = mutable.Map[String, Column](
-      "bounds" -> lit(null).cast(ArrayType(DoubleType)).alias("bounds"),
-      "validNumerics" -> lit(null).cast(ArrayType(DoubleType)).alias("validNumerics"),
-      "validStrings" -> lit(null).cast(ArrayType(StringType)).alias("validStrings"),
-      "validDate" -> lit(null).cast(LongType).alias("validDate")
+      RuleType.ValidateBounds.toString -> lit(null).cast(ArrayType(DoubleType)).alias(RuleType.ValidateBounds.toString),
+      RuleType.ValidateNumerics.toString -> lit(null).cast(ArrayType(DoubleType)).alias(RuleType.ValidateNumerics.toString),
+      RuleType.ValidateStrings.toString -> lit(null).cast(ArrayType(StringType)).alias(RuleType.ValidateStrings.toString),
+      RuleType.ValidateDateTime.toString -> lit(null).cast(LongType).alias(RuleType.ValidateDateTime.toString)
     )
     rule.ruleType match {
-      case "bounds" => nulls("bounds") = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias("bounds")
-      case "validNumerics" => nulls("validNumerics") = lit(rule.validNumerics).alias("validNumerics")
-      case "validStrings" => nulls("validStrings") = lit(rule.validStrings).alias("validStrings")
+      case RuleType.ValidateBounds => nulls(RuleType.ValidateBounds.toString) = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias(RuleType.ValidateBounds.toString)
+      case RuleType.ValidateNumerics => nulls(RuleType.ValidateNumerics.toString) = lit(rule.validNumerics).alias(RuleType.ValidateNumerics.toString)
+      case RuleType.ValidateStrings => nulls(RuleType.ValidateStrings.toString) = lit(rule.validStrings).alias(RuleType.ValidateStrings.toString)
     }
     val validationsByType = nulls.toMap.values.toSeq
     struct(
@@ -61,7 +61,7 @@ private def buildOutputStruct(rule: Rule, results: Seq[Column]): Column = {
     struct(
       lit(rule.ruleName).alias("Rule_Name"),
-      lit(rule.ruleType).alias("Rule_Type"),
+      lit(rule.ruleType.toString).alias("Rule_Type"),
       buildValidationsByType(rule),
       struct(results: _*).alias("Results")
     ).alias("Validation")
   }
@@ -101,8 +101,27 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
       // Results must have Invalid_Count & Failed
       rule.ruleType match {
-        case "bounds" =>
+        case RuleType.ValidateBounds =>
+          // Rule evaluation for NON-AGG RULES ONLY
           val invalid = rule.inputColumn < rule.boundaries.lower || rule.inputColumn > rule.boundaries.upper
+          // This is the first select; it must come before subsequent selects because it aliases the original
+          // column name to the rule name. ADDITIONALLY, it evaluates the boundary rule WHEN the input col is
+          // NOT an agg. This can be confusing: for non-agg columns the column is renamed to the rule name AND
+          // holds the count of rows that violated the bounds (not the original value).
+          // IF the rule IS an agg, the column is simply aliased to the rule name and no evaluation takes
+          // place here.
+          val first = if (!rule.isAgg) { // Not Agg
+            sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
+          } else { // Is Agg
+            rule.inputColumn.alias(rule.ruleName)
+          }
+          // WHEN RULE IS AGG -- this is where the evaluation happens. The input column was renamed to the
+          // name of the rule in the required previous select.
+          // IMPORTANT: REMEMBER -- an agg expression evaluates to a single output value, so the Invalid_Count
+          // in the agg case can never be > 1; the single aggregated value either passes (0) or fails (1).
+
+          // WHEN RULE NOT AGG -- determine whether the result of the "first" select (the invalid-row count) is
+          // > 0; if it is, the rule has failed, because a positive sum means one or more rows failed the rule.
           val failed = if (rule.isAgg) {
             when(
               col(rule.ruleName) < rule.boundaries.lower ||
                 col(rule.ruleName) > rule.boundaries.upper, true)
@@ -110,19 +129,14 @@
           } else{
             when(col(rule.ruleName) > 0,true).otherwise(false).alias("Failed")
           }
-          val first = if (!rule.isAgg) { // Not Agg
-            sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
-          } else { // Is Agg
-            rule.inputColumn.alias(rule.ruleName)
-          }
           val results = if (rule.isAgg) {
             Seq(when(failed, 1).otherwise(0).cast(LongType).alias("Invalid_Count"), failed)
           } else {
             Seq(col(rule.ruleName).cast(LongType).alias("Invalid_Count"), failed)
           }
           Selects(buildOutputStruct(rule, results), first)
-        case x if x == "validNumerics" || x == "validStrings" =>
-          val invalid = if (x == "validNumerics") {
+        case x if x == RuleType.ValidateNumerics || x == RuleType.ValidateStrings =>
+          val invalid = if (x == RuleType.ValidateNumerics) {
             expr(s"size(array_except(${rule.ruleName}," +
               s"array(${rule.validNumerics.mkString("D,")}D)))")
           } else {
@@ -134,8 +148,8 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
           val first = collect_set(rule.inputColumn).alias(rule.ruleName)
           val results = Seq(invalid.cast(LongType).alias("Invalid_Count"), failed)
           Selects(buildOutputStruct(rule, results), first)
-        case "validDate" => ??? // TODO
-        case "complex" => ??? // TODO
+        case RuleType.ValidateDateTime => ??? // TODO
+        case RuleType.ValidateComplex => ??? // TODO
       }
     })
   }
diff --git a/src/test/resources/log4j.properties b/src/test/resources/log4j.properties
new file mode 100644
index 0000000..a87661a
--- /dev/null
+++ b/src/test/resources/log4j.properties
@@ -0,0 +1,13 @@
+# Set everything to be logged to the console
+log4j.rootCategory=ERROR, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+
+# Settings to quiet third party logs that are too verbose
+log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=WARN
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=WARN
+log4j.logger.org.apache.spark.sql.SparkSession$Builder=ERROR
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/labs/validation/SparkSessionFixture.scala b/src/test/scala/com/databricks/labs/validation/SparkSessionFixture.scala
new file mode 100644
index 0000000..094151d
--- /dev/null
+++ b/src/test/scala/com/databricks/labs/validation/SparkSessionFixture.scala
@@ -0,0 +1,12 @@
+package com.databricks.labs.validation
+
+import org.apache.spark.sql.SparkSession
+
+trait SparkSessionFixture {
+  lazy val spark = SparkSession
+    .builder()
+    .master("local")
+    .appName("spark session")
+    .config("spark.sql.shuffle.partitions", "1")
+    .getOrCreate()
+}
diff --git a/src/test/scala/com/databricks/labs/validation/ValidatorTestSuite.scala b/src/test/scala/com/databricks/labs/validation/ValidatorTestSuite.scala
new file mode 100644
index 0000000..2ac3080
--- /dev/null
+++ b/src/test/scala/com/databricks/labs/validation/ValidatorTestSuite.scala
@@ -0,0 +1,278 @@
+package com.databricks.labs.validation
+
+import com.databricks.labs.validation.utils.Structures.{Bounds, MinMaxRuleDef}
+import org.apache.spark.sql.functions.{col, min}
+
+case class ValidationValue(validDateTime: java.lang.Long, validNumerics: Array[Double], bounds: Array[Double], validStrings: Array[String])
+
+class ValidatorTestSuite extends org.scalatest.FunSuite with SparkSessionFixture {
+
+  import spark.implicits._
+  spark.sparkContext.setLogLevel("ERROR")
+
+  test("The input dataframe should have no rule failures on MinMaxRule") {
+    val expectedDF = Seq(
+      ("MinMax_Cost_Generated_max","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+      ("MinMax_Cost_Generated_min","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+      ("MinMax_Cost_manual_max","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+      ("MinMax_Cost_manual_min","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+      ("MinMax_Cost_max","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+      ("MinMax_Cost_min","bounds",ValidationValue(null,null,Array(0.0, 12.0),null),0,false),
+      ("MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      ("MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      ("MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      ("MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+    ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+    val data = Seq()
+    // 2 per rule so 2 MinMax_Sku_Price + 2 MinMax_Scan_Price + 2 MinMax_Cost + 2 MinMax_Cost_Generated
+    // + 2 MinMax_Cost_manual = 10 rules
+    val testDF = Seq(
+      (1, 2, 3),
+      (4, 5, 6),
+      (7, 8, 9)
+    ).toDF("retail_price", "scan_price", "cost")
+    val minMaxPriceDefs = Array(
+      MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 29.99)),
+      MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99)),
+      MinMaxRuleDef("MinMax_Cost", col("cost"), Bounds(0.0, 12.0))
+    )
+
+    // Generate the array of Rules from the minmax generator
+    val rulesArray = RuleSet.generateMinMaxRules(MinMaxRuleDef("MinMax_Cost_Generated", col("cost"), Bounds(0.0, 12.0)))
+
+    val someRuleSet = RuleSet(testDF)
+    someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+    someRuleSet.addMinMaxRules("MinMax_Cost_manual", col("cost"), Bounds(0.0,12.0))
+    someRuleSet.add(rulesArray)
+    val (rulesReport, passed) = someRuleSet.validate()
+    assert(rulesReport.except(expectedDF).count() == 0)
+    assert(passed)
+    assert(rulesReport.count() == 10)
+  }
+
+  test("The input rule should have 1 invalid count for MinMax_Retail_Price_Minus_Scan_Price_min and max when a calculated column fails its bounds.") {
+    val expectedDF = Seq(
+      ("MinMax_Retail_Price_Minus_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),1,true),
+      ("MinMax_Retail_Price_Minus_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),1,true),
+      ("MinMax_Scan_Price_Minus_Retail_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      ("MinMax_Scan_Price_Minus_Retail_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+    ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+
+    val testDF = Seq(
+      (1, 2, 3),
+      (4, 5, 6),
+      (7, 8, 9)
+    ).toDF("retail_price", "scan_price", "cost")
+    val minMaxPriceDefs = Array(
+      MinMaxRuleDef("MinMax_Retail_Price_Minus_Scan_Price", col("retail_price") - col("scan_price"), Bounds(0.0, 29.99)),
+      MinMaxRuleDef("MinMax_Scan_Price_Minus_Retail_Price", col("scan_price") - col("retail_price"), Bounds(0.0, 29.99))
+    )
+
+    // Build the RuleSet from the MinMax rule definitions
+    val someRuleSet = RuleSet(testDF)
+    someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+    val (rulesReport, passed) = someRuleSet.validate()
+    assert(rulesReport.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+    assert(!passed)
+    assert(rulesReport.count() == 4)
+  }
+
+  test("The input rule should have 1 invalid count for a failing aggregate rule.") {
+    val expectedDF = Seq(
+      ("MinMax_Min_Retail_Price","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      ("MinMax_Min_Scan_Price","bounds",ValidationValue(null,null,Array(3.0, 29.99),null),1,true)
+    ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+    val testDF = Seq(
+      (1, 2, 3),
+      (4, 5, 6),
+      (7, 8, 9)
+    ).toDF("retail_price", "scan_price", "cost")
+    val minMaxPriceDefs = Seq(
+      Rule("MinMax_Min_Retail_Price", min("retail_price"), Bounds(0.0, 29.99)),
+      Rule("MinMax_Min_Scan_Price", min("scan_price"), Bounds(3.0, 29.99))
+    )
+
+
+    // Build the RuleSet from the aggregate rules
+    val someRuleSet = RuleSet(testDF)
+    someRuleSet.add(minMaxPriceDefs)
+    val (rulesReport, passed) = someRuleSet.validate()
+    assert(rulesReport.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+    assert(!passed)
+    assert(rulesReport.count() == 2)
+  }
+
+  test("The input dataframe should have exactly 1 rule failure on MinMaxRule") {
+    val expectedDF = Seq(
+      ("MinMax_Cost_max","bounds",ValidationValue(null,null,Array(0.0, 12.00),null),1,true),
+      ("MinMax_Cost_min","bounds",ValidationValue(null,null,Array(0.0, 12.00),null),0,false),
+      ("MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      ("MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      ("MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      ("MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+    ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+    val testDF = Seq(
+      (1, 2, 3),
+      (4, 5, 6),
+      (7, 8, 99)
+    ).toDF("retail_price", "scan_price", "cost")
+    val minMaxPriceDefs = Array(
+      MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 29.99)),
+      MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99)),
+      MinMaxRuleDef("MinMax_Cost", col("cost"), Bounds(0.0, 12.0))
+    )
+    // Build the RuleSet from the MinMax rule definitions
+
+    val someRuleSet = RuleSet(testDF)
+    someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+    val (rulesReport, passed) = someRuleSet.validate()
+    val failedResults = rulesReport.filter(rulesReport("Invalid_Count") > 0).collect()
+    assert(failedResults.length == 1)
+    assert(rulesReport.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+    assert(failedResults(0)(0) == "MinMax_Cost_max")
+    assert(!passed)
+  }
+
+  test("The DF in the RuleSet object is the same as the input test df") {
+    val testDF = Seq(
+      (1, 2, 3),
+      (4, 5, 6),
+      (7, 8, 99)
+    ).toDF("retail_price", "scan_price", "cost")
+    val minMaxPriceDefs = Array(
+      MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 29.99)),
+      MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99)),
+      MinMaxRuleDef("MinMax_Cost", col("cost"), Bounds(0.0, 12.0))
+    )
+    // Build the RuleSet from the MinMax rule definitions
+
+    val someRuleSet = RuleSet(testDF)
+    someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+    val rulesDf = someRuleSet.getDf
+    assert(testDF.except(rulesDf).count() == 0)
+  }
+
+  test("The group by columns are used as the grouping clauses in the validation") {
+    val expectedDF = Seq(
+      (3,"MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (6,"MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (3,"MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (6,"MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (3,"MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (6,"MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (3,"MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (6,"MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+    ).toDF("cost","Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+    // 2 groups, so the rule count should yield (2 MinMax rule defs * 2 rules each) * 2 groups in cost (8 rows)
+    val testDF = Seq(
+      (1, 2, 3),
+      (4, 5, 6),
+      (7, 8, 3)
+    ).toDF("retail_price", "scan_price", "cost")
+    val minMaxPriceDefs = Array(
+      MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 29.99)),
+      MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99))
+    )
+
+    val someRuleSet = RuleSet(testDF, "cost")
+    someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+    val groupBys = someRuleSet.getGroupBys
+    val (groupByValidated, passed) = someRuleSet.validate()
+
+    assert(groupBys.length == 1)
+    assert(groupBys.head == "cost")
+    assert(someRuleSet.isGrouped)
+    assert(passed)
+    assert(groupByValidated.count() == 8)
+    assert(groupByValidated.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+    assert(groupByValidated.filter(groupByValidated("Invalid_Count") > 0).count() == 0)
+    assert(groupByValidated.filter(groupByValidated("Failed") === true).count() == 0)
+  }
+
+  test("The group by validation reports rule failures correctly per group") {
+    val expectedDF = Seq(
+      (3,"MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 0.0),null),1,true),
+      (6,"MinMax_Sku_Price_max","bounds",ValidationValue(null,null,Array(0.0, 0.0),null),1,true),
+      (3,"MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 0.0),null),1,true),
+      (6,"MinMax_Sku_Price_min","bounds",ValidationValue(null,null,Array(0.0, 0.0),null),1,true),
+      (3,"MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (6,"MinMax_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (3,"MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
+      (6,"MinMax_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
+    ).toDF("cost","Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+    // 2 groups, so the rule count should yield (2 MinMax rule defs * 2 rules each) * 2 groups in cost (8 rows)
+    val testDF = Seq(
+      (1, 2, 3),
+      (4, 5, 6),
+      (7, 8, 3)
+    ).toDF("retail_price", "scan_price", "cost")
+    val minMaxPriceDefs = Array(
+      MinMaxRuleDef("MinMax_Sku_Price", col("retail_price"), Bounds(0.0, 0.0)),
+      MinMaxRuleDef("MinMax_Scan_Price", col("scan_price"), Bounds(0.0, 29.99))
+    )
+
+    val someRuleSet = RuleSet(testDF, "cost")
+    someRuleSet.addMinMaxRules(minMaxPriceDefs: _*)
+    val groupBys = someRuleSet.getGroupBys
+    val (groupByValidated, passed) = someRuleSet.validate()
+
+    assert(groupBys.length == 1, "Group by length is not 1")
+    assert(groupBys.head == "cost", "Group by column is not cost")
+    assert(someRuleSet.isGrouped)
+    assert(!passed, "Rule set did not fail.")
+    assert(groupByValidated.count() == 8, "Rule count should be 8")
+    assert(groupByValidated.except(expectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+    assert(groupByValidated.filter(groupByValidated("Invalid_Count") > 0).count() == 4, "Invalid count is not 4.")
+    assert(groupByValidated.filter(groupByValidated("Failed") === true).count() == 4, "Failed count is not 4.")
+  }
+
+  test("Validate list of values with numeric types, string types and long types.") {
+
+    val testDF = Seq(
+      ("food_a", 2.51, 3, 111111111111111L),
+      ("food_b", 5.11, 6, 211111111111111L),
+      ("food_c", 8.22, 99, 311111111111111L)
+    ).toDF("product_name", "scan_price", "cost", "id")
+
+    val numericLovExpectedDF = Seq(
+      ("CheckIfCostIsInLOV","validNumerics",ValidationValue(null,Array(3,6,99),null,null),0,false),
+      ("CheckIfScanPriceIsInLOV","validNumerics",ValidationValue(null,Array(2.51,5.11,8.22),null,null),0,false),
+      ("CheckIfIdIsInLOV","validNumerics",ValidationValue(null,Array(111111111111111L,211111111111111L,311111111111111L),null,null),0,false)
+    ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+    val numericRules = Array(
+      Rule("CheckIfCostIsInLOV", col("cost"), Array(3,6,99)),
+      Rule("CheckIfScanPriceIsInLOV", col("scan_price"), Array(2.51,5.11,8.22)),
+      Rule("CheckIfIdIsInLOV", col("id"), Array(111111111111111L,211111111111111L,311111111111111L))
+    )
+    // Build the RuleSet from the numeric list-of-values rules
+
+    val numericRuleSet = RuleSet(testDF)
+    numericRuleSet.add(numericRules)
+    val (numericValidated, numericPassed) = numericRuleSet.validate()
+    assert(numericRules.map(_.ruleType == RuleType.ValidateNumerics).reduce(_ && _), "Not every rule is ValidateNumerics.")
+    assert(numericRules.map(_.boundaries == null).reduce(_ && _), "Boundaries are not null.")
+    assert(numericPassed)
+    assert(numericValidated.except(numericLovExpectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+    assert(numericValidated.filter(numericValidated("Invalid_Count") > 0).count() == 0)
+    assert(numericValidated.filter(numericValidated("Failed") === true).count() == 0)
+
+    val stringRule = Rule("CheckIfProductNameInLOV", col("product_name"), Array("food_a","food_b","food_c"))
+    // Build the RuleSet from the string list-of-values rule
+
+    val stringLovExpectedDF = Seq(
+      ("CheckIfProductNameInLOV","validStrings",ValidationValue(null,null,null,Array("food_a", "food_b", "food_c")),0,false)
+    ).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")
+
+    val stringRuleSet = RuleSet(testDF)
+    stringRuleSet.add(stringRule)
+    val (stringValidated, stringPassed) = stringRuleSet.validate()
+    assert(stringRule.ruleType == RuleType.ValidateStrings)
+    assert(stringRule.boundaries == null)
+    assert(stringPassed)
+    assert(stringValidated.except(stringLovExpectedDF).count() == 0, "Expected df is not equal to the returned rules report.")
+    assert(stringValidated.filter(stringValidated("Invalid_Count") > 0).count() == 0)
+    assert(stringValidated.filter(stringValidated("Failed") === true).count() == 0)
+  }
+
+
+}
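For reference, a minimal standalone sketch (not part of the diff above) of how calling code consumes the RuleType enumeration introduced in this patch. It assumes only the Rule/RuleSet/RuleType API shown in the diff; the object name `RuleTypeUsage` and the DataFrame `demoDF` are illustrative, not part of the library:

```scala
package com.databricks.labs.validation

import com.databricks.labs.validation.utils.Structures.Bounds
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Hypothetical driver illustrating the RuleType enumeration; names here are examples only
object RuleTypeUsage extends App {

  val spark = SparkSession.builder().master("local").appName("ruleTypeUsage").getOrCreate()
  import spark.implicits._

  val demoDF = Seq((1, 2), (4, 5)).toDF("retail_price", "scan_price")

  // A Rule built with Bounds now carries RuleType.ValidateBounds instead of the raw String "bounds"
  val boundsRule = Rule("Retail_Price_In_Range", col("retail_price"), Bounds(0.0, 10.0))
  assert(boundsRule.ruleType == RuleType.ValidateBounds)

  // The enumeration allows typo-proof matching, as the Validator now does internally
  val description = boundsRule.ruleType match {
    case RuleType.ValidateBounds   => "boundary rule"
    case RuleType.ValidateNumerics => "numeric list-of-values rule"
    case RuleType.ValidateStrings  => "string list-of-values rule"
    case RuleType.ValidateDateTime => "date/time rule (TODO in Validator)"
    case RuleType.ValidateComplex  => "complex rule (TODO in Validator)"
  }
  println(description)

  // The report still sees the original string form via toString, e.g. "bounds"
  println(boundsRule.ruleType.toString)

  val (report, passed) = RuleSet(demoDF).add(boundsRule).validate()
  report.show(20, false)
}
```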