2 changes: 1 addition & 1 deletion .gitignore
@@ -9,4 +9,4 @@ project/project/
project/target/
*.DS_Store
/target/
/project/
/project/build.properties
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
@@ -1,2 +1,3 @@
We happily welcome contributions to *Databricks Labs - dataframe-rules-engine*.
We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes.
We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes.
Please make a fork of this repository and submit a pull request.
18 changes: 18 additions & 0 deletions README.md
@@ -182,3 +182,21 @@ cd Downloads
git pull repo
sbt clean package
```

## Running tests
To run tests on the project: <br>
```
sbt test
```

Make sure that your JAVA_HOME is set up so that sbt can run the tests properly. You will need JDK 8, as Spark does
not support newer versions of the JDK.

## Test coverage reports
To get a test coverage report for the project: <br>
```
sbt jacoco
```

The test reports can be found in `target/scala-<version>/jacoco/`.

20 changes: 20 additions & 0 deletions build.sbt
@@ -9,6 +9,26 @@ scalacOptions ++= Seq("-Xmax-classfile-name", "78")

libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.0"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0"
libraryDependencies += "org.scalactic" %% "scalactic" % "3.1.1"
libraryDependencies += "org.scalatest" %% "scalatest" % "3.1.1" % "test"

lazy val excludes = jacocoExcludes in Test := Seq()

lazy val jacoco = jacocoReportSettings in test :=JacocoReportSettings(
"Jacoco Scala Example Coverage Report",
None,
JacocoThresholds (branch = 100),
Seq(JacocoReportFormats.ScalaHTML,
JacocoReportFormats.CSV),
"utf-8")

val jacocoSettings = Seq(jacoco)
lazy val jse = (project in file (".")).settings(jacocoSettings: _*)

fork in Test := true
javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:+CMSClassUnloadingEnabled")
testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, "-oD")


lazy val commonSettings = Seq(
version := "0.1.1",
1 change: 1 addition & 0 deletions project/plugins.sbt
@@ -0,0 +1 @@
addSbtPlugin("com.github.sbt" % "sbt-jacoco" % "3.0.3")
37 changes: 37 additions & 0 deletions src/main/scala/com/databricks/labs/validation/QuickTest.scala
@@ -0,0 +1,37 @@
package com.databricks.labs.validation

import com.databricks.labs.validation.utils.Structures.{Bounds, MinMaxRuleDef}
import com.databricks.labs.validation.utils.SparkSessionWrapper
import org.apache.spark.sql.functions._

object QuickTest extends App with SparkSessionWrapper {

import spark.implicits._

val testDF = Seq(
(1, 2, 3),
(4, 5, 6),
(7, 8, 9)
).toDF("retail_price", "scan_price", "cost")

Rule("Reasonable_sku_counts", count(col("sku")), Bounds(lower = 20.0, upper = 200.0))

val minMaxPriceDefs = Array(
MinMaxRuleDef("MinMax_Retail_Price_Minus_Scan_Price", col("retail_price")-col("scan_price"), Bounds(0.0, 29.99)),
MinMaxRuleDef("MinMax_Scan_Price_Minus_Retail_Price", col("scan_price")-col("retail_price"), Bounds(0.0, 29.99))
)

val someRuleSet = RuleSet(testDF)
.add(Rule("retail_pass", col("retail_price"), Bounds(lower = 1.0, upper = 7.0)))
.add(Rule("retail_agg_pass_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 7.1)))
.add(Rule("retail_agg_pass_low", min(col("retail_price")), Bounds(lower = 0.0, upper = 7.0)))
.add(Rule("retail_fail_low", col("retail_price"), Bounds(lower = 1.1, upper = 7.0)))
.add(Rule("retail_fail_high", col("retail_price"), Bounds(lower = 0.0, upper = 6.9)))
.add(Rule("retail_agg_fail_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 6.9)))
.add(Rule("retail_agg_fail_low", min(col("retail_price")), Bounds(lower = 1.1, upper = 7.0)))
.addMinMaxRules(minMaxPriceDefs: _*)
val (rulesReport, passed) = someRuleSet.validate()

testDF.show(20, false)
rulesReport.show(20, false)
}
16 changes: 8 additions & 8 deletions src/main/scala/com/databricks/labs/validation/Rule.scala
@@ -20,7 +20,7 @@ class Rule {
private var _validNumerics: Array[Double] = _
private var _validStrings: Array[String] = _
private var _dateTimeLogic: Column = _
private var _ruleType: String = _
private var _ruleType: RuleType.Value = _
private var _isAgg: Boolean = _

private def setRuleName(value: String): this.type = {
@@ -69,7 +69,7 @@
this
}

private def setRuleType(value: String): this.type = {
private def setRuleType(value: RuleType.Value): this.type = {
_ruleType = value
this
}
@@ -99,7 +99,7 @@

def dateTimeLogic: Column = _dateTimeLogic

def ruleType: String = _ruleType
def ruleType: RuleType.Value = _ruleType

private[validation] def isAgg: Boolean = _isAgg

@@ -121,7 +121,7 @@ object Rule {
.setRuleName(ruleName)
.setColumn(column)
.setBoundaries(boundaries)
.setRuleType("bounds")
.setRuleType(RuleType.ValidateBounds)
.setIsAgg
}

@@ -135,7 +135,7 @@
.setRuleName(ruleName)
.setColumn(column)
.setValidNumerics(validNumerics)
.setRuleType("validNumerics")
.setRuleType(RuleType.ValidateNumerics)
.setIsAgg
}

@@ -149,7 +149,7 @@
.setRuleName(ruleName)
.setColumn(column)
.setValidNumerics(validNumerics.map(_.toString.toDouble))
.setRuleType("validNumerics")
.setRuleType(RuleType.ValidateNumerics)
.setIsAgg
}

@@ -163,7 +163,7 @@
.setRuleName(ruleName)
.setColumn(column)
.setValidNumerics(validNumerics.map(_.toString.toDouble))
.setRuleType("validNumerics")
.setRuleType(RuleType.ValidateNumerics)
.setIsAgg
}

@@ -177,7 +177,7 @@
.setRuleName(ruleName)
.setColumn(column)
.setValidStrings(validStrings)
.setRuleType("validStrings")
.setRuleType(RuleType.ValidateStrings)
.setIsAgg
}

12 changes: 12 additions & 0 deletions src/main/scala/com/databricks/labs/validation/RuleType.scala
@@ -0,0 +1,12 @@
package com.databricks.labs.validation

/**
* Definition of the Rule Types as an Enumeration for better type matching
*/
object RuleType extends Enumeration {
val ValidateBounds = Value("bounds")
val ValidateNumerics = Value("validNumerics")
val ValidateStrings = Value("validStrings")
val ValidateDateTime = Value("validDateTime")
val ValidateComplex = Value("complex")
}
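
The enumeration keeps the original string labels (via `toString`) for the report output while replacing the raw strings previously used throughout `Rule` and `Validator`. A minimal sketch, not part of this PR, of how the typed values behave:

```scala
package com.databricks.labs.validation

// Minimal usage sketch (not part of this PR) for the RuleType enumeration above.
object RuleTypeSketch extends App {

  val rt: RuleType.Value = RuleType.ValidateBounds

  // toString still yields the legacy string label that lands in the Rule_Type report column
  println(rt.toString) // prints "bounds"

  // Matching on enumeration values instead of raw strings means a misspelled case
  // (e.g. RuleType.ValidateBonds) fails to compile instead of silently never matching.
  val description = rt match {
    case RuleType.ValidateBounds => "numeric boundary check"
    case RuleType.ValidateNumerics | RuleType.ValidateStrings => "categorical lookup"
    case RuleType.ValidateDateTime | RuleType.ValidateComplex => "not implemented yet"
  }
  println(description) // prints "numeric boundary check"
}
```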
60 changes: 37 additions & 23 deletions src/main/scala/com/databricks/labs/validation/Validator.scala
@@ -13,11 +13,11 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {

import spark.implicits._

private val boundaryRules = ruleSet.getRules.filter(_.ruleType == "bounds")
private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == "validNumerics" ||
rule.ruleType == "validStrings")
private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == "dateTime")
private val complexRules = ruleSet.getRules.filter(_.ruleType == "complex")
private val boundaryRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateBounds)
private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == RuleType.ValidateNumerics ||
rule.ruleType == RuleType.ValidateStrings)
private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateDateTime)
private val complexRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateComplex)
private val byCols = ruleSet.getGroupBys map col

/**
@@ -36,15 +36,15 @@
*/
private def buildValidationsByType(rule: Rule): Column = {
val nulls = mutable.Map[String, Column](
"bounds" -> lit(null).cast(ArrayType(DoubleType)).alias("bounds"),
"validNumerics" -> lit(null).cast(ArrayType(DoubleType)).alias("validNumerics"),
"validStrings" -> lit(null).cast(ArrayType(StringType)).alias("validStrings"),
"validDate" -> lit(null).cast(LongType).alias("validDate")
RuleType.ValidateBounds.toString -> lit(null).cast(ArrayType(DoubleType)).alias(RuleType.ValidateBounds.toString),
RuleType.ValidateNumerics.toString -> lit(null).cast(ArrayType(DoubleType)).alias(RuleType.ValidateNumerics.toString),
RuleType.ValidateStrings.toString -> lit(null).cast(ArrayType(StringType)).alias(RuleType.ValidateStrings.toString),
RuleType.ValidateDateTime.toString -> lit(null).cast(LongType).alias(RuleType.ValidateDateTime.toString)
)
rule.ruleType match {
case "bounds" => nulls("bounds") = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias("bounds")
case "validNumerics" => nulls("validNumerics") = lit(rule.validNumerics).alias("validNumerics")
case "validStrings" => nulls("validStrings") = lit(rule.validStrings).alias("validStrings")
case RuleType.ValidateBounds => nulls(RuleType.ValidateBounds.toString) = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias(RuleType.ValidateBounds.toString)
case RuleType.ValidateNumerics => nulls(RuleType.ValidateNumerics.toString) = lit(rule.validNumerics).alias(RuleType.ValidateNumerics.toString)
case RuleType.ValidateStrings => nulls(RuleType.ValidateStrings.toString) = lit(rule.validStrings).alias(RuleType.ValidateStrings.toString)
}
val validationsByType = nulls.toMap.values.toSeq
struct(
@@ -61,7 +61,7 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
private def buildOutputStruct(rule: Rule, results: Seq[Column]): Column = {
struct(
lit(rule.ruleName).alias("Rule_Name"),
lit(rule.ruleType).alias("Rule_Type"),
lit(rule.ruleType.toString).alias("Rule_Type"),
buildValidationsByType(rule),
struct(results: _*).alias("Results")
).alias("Validation")
@@ -101,28 +101,42 @@

// Results must have Invalid_Count & Failed
rule.ruleType match {
case "bounds" =>
case RuleType.ValidateBounds =>
// Rule evaluation for NON-AGG RULES ONLY
val invalid = rule.inputColumn < rule.boundaries.lower || rule.inputColumn > rule.boundaries.upper
// This is the first select it must come before subsequent selects as it aliases the original column name
// to that of the rule name. ADDITIONALLY, this evaluates the boundary rule WHEN the input col is not an Agg.
// This can be confusing because for Non-agg columns it renames the column to the rule_name AND returns a 0
// or 1 (not the original value)
// IF the rule is NOT an AGG then the column is simply aliased to the rule name and no evaluation takes place
// here.
val first = if (!rule.isAgg) { // Not Agg
sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
} else { // Is Agg
rule.inputColumn.alias(rule.ruleName)
}
// WHEN RULE IS AGG -- this is where the evaluation happens. The input column was renamed to the name of the
// rule in the required previous select.
// IMPORTANT: REMEMBER - that agg expressions evaluate to a single output value thus the invalid_count in
// cases where agg is used cannot be > 1 since the sum of a single value cannot exceed 1.

// WHEN RULE NOT AGG - determine if the result of "first" select (0 or 1) is > 0, if it is, the rule has
// failed since the sum(1 or more 1s) means that 1 or more rows have failed thus the rule has failed
val failed = if (rule.isAgg) {
when(
col(rule.ruleName) < rule.boundaries.lower || col(rule.ruleName) > rule.boundaries.upper, true)
.otherwise(false).alias("Failed")
} else{
when(col(rule.ruleName) > 0,true).otherwise(false).alias("Failed")
}
val first = if (!rule.isAgg) { // Not Agg
sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
} else { // Is Agg
rule.inputColumn.alias(rule.ruleName)
}
val results = if (rule.isAgg) {
Seq(when(failed, 1).otherwise(0).cast(LongType).alias("Invalid_Count"), failed)
} else {
Seq(col(rule.ruleName).cast(LongType).alias("Invalid_Count"), failed)
}
Selects(buildOutputStruct(rule, results), first)
case x if x == "validNumerics" || x == "validStrings" =>
val invalid = if (x == "validNumerics") {
case x if x == RuleType.ValidateNumerics || x == RuleType.ValidateStrings =>
val invalid = if (x == RuleType.ValidateNumerics) {
expr(s"size(array_except(${rule.ruleName}," +
s"array(${rule.validNumerics.mkString("D,")}D)))")
} else {
Expand All @@ -134,8 +148,8 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
val first = collect_set(rule.inputColumn).alias(rule.ruleName)
val results = Seq(invalid.cast(LongType).alias("Invalid_Count"), failed)
Selects(buildOutputStruct(rule, results), first)
case "validDate" => ??? // TODO
case "complex" => ??? // TODO
case RuleType.ValidateDateTime => ??? // TODO
case RuleType.ValidateComplex => ??? // TODO
}
})
}
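
The comments added above capture the key distinction: a non-agg boundary rule sums a per-row 0/1 flag, while an agg rule collapses to a single value before the bounds check, so its Invalid_Count can never exceed 1. A standalone sketch of the two evaluation paths, not part of this PR and assuming only Spark 2.4 on the classpath:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Standalone sketch (not part of this PR) of the agg vs non-agg boundary
// evaluation described in the comments above.
object BoundsEvaluationSketch extends App {

  val spark = SparkSession.builder().master("local[1]").appName("bounds-sketch").getOrCreate()
  import spark.implicits._

  val df = Seq(1, 5, 9).toDF("retail_price")
  val (lower, upper) = (2.0, 7.0)

  // Non-agg rule: every row is flagged 0/1 and the flags are summed, so
  // Invalid_Count reflects the number of failing rows (here 2: values 1 and 9).
  df.select(
    sum(when(col("retail_price") < lower || col("retail_price") > upper, 1).otherwise(0))
      .alias("Invalid_Count")
  ).show()

  // Agg rule: the aggregate collapses to a single value first, and only that value
  // is compared against the bounds, so Invalid_Count can never exceed 1.
  df.select(max(col("retail_price")).alias("retail_agg"))
    .select(when(col("retail_agg") < lower || col("retail_agg") > upper, true)
      .otherwise(false).alias("Failed"))
    .show()

  spark.stop()
}
```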
13 changes: 13 additions & 0 deletions src/test/resources/log4j.properties
@@ -0,0 +1,13 @@
# Set everything to be logged to the console
log4j.rootCategory=ERROR, console
log4j.appender.console=org.apache.log4j.ConsoleAppender
log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n

# Settings to quiet third party logs that are too verbose
log4j.logger.org.eclipse.jetty=WARN
log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=WARN
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=WARN
log4j.logger.org.apache.spark.sql.SparkSession$Builder=ERROR
@@ -0,0 +1,12 @@
package com.databricks.labs.validation

import org.apache.spark.sql.SparkSession

trait SparkSessionFixture {
lazy val spark = SparkSession
.builder()
.master("local")
.appName("spark session")
.config("spark.sql.shuffle.partitions", "1")
.getOrCreate()
}
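
A hypothetical ScalaTest suite, not included in this PR, showing how the fixture might be mixed in; the suite name and assertions are illustrative only:

```scala
package com.databricks.labs.validation

import org.scalatest.funsuite.AnyFunSuite

// Hypothetical suite (not part of this PR) that mixes in the fixture so every
// test shares the single local SparkSession defined above.
class SparkSessionFixtureSketch extends AnyFunSuite with SparkSessionFixture {

  import spark.implicits._

  test("shuffle partitions are pinned to 1 for deterministic local runs") {
    assert(spark.conf.get("spark.sql.shuffle.partitions") == "1")
  }

  test("a tiny DataFrame can be built against the shared session") {
    val df = Seq((1, "a"), (2, "b")).toDF("id", "label")
    assert(df.count() == 2)
  }
}
```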