11 changes: 9 additions & 2 deletions R/pkg/R/mllib_tree.R
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#' can speed up training of deeper trees. Users can set how often should the
#' cache be checkpointed or disable it by setting checkpointInterval.
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
Member:
is this on "features" or "labels"? it seems it's only set on RFormula.terms which are features

Contributor Author:
I think "labels" here means the string labels of a feature that is categorical (e.g., white, black).
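
To make the terminology concrete, a minimal sketch of an unseen categorical level (the column names are illustrative, not taken from this PR):

```r
library(SparkR)
sparkR.session()

# "color" is a categorical string feature; the test set contains the level
# "red", which never appears in training, so RFormula's StringIndexer has to
# decide what to do with it according to handleInvalid.
train <- as.DataFrame(data.frame(clicked = c(0, 1, 1, 0),
                                 color = c("white", "black", "white", "black"),
                                 stringsAsFactors = FALSE))
test <- as.DataFrame(data.frame(clicked = 1, color = "red",
                                stringsAsFactors = FALSE))
model <- spark.randomForest(train, clicked ~ color, type = "classification",
                            handleInvalid = "skip")
head(predict(model, test))   # with "skip", the unseen "red" row is filtered out
```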

#' Supported options: "skip" (filter out rows with invalid data),
#' "error" (throw an error), "keep" (put invalid data in a special additional
Member:
is "error" the default behavior? since we are doing handleInvalid = c("error", "skip", "keep") could you add text to say it's the default. also consider sorting the options - default first "error", then "keep", "skip" if we are alphabetical?

Contributor Author:
Yes, "error" is the default behavior; the backend code has setDefault. I will reorder the options and add the text to the documentation.
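
For reference, match.arg() treats the first element of the candidate vector as the value when the caller leaves the argument at its default, so listing "error" first in the signature also makes it the R-side default. A small illustration of that behavior (not PR code):

```r
f <- function(handleInvalid = c("error", "keep", "skip")) {
  # With the argument left at its default, match.arg() returns the first candidate.
  match.arg(handleInvalid)
}
f()        # "error"
f("skip")  # "skip"
# f("oops") would signal an error: 'arg' should be one of "error", "keep", "skip"
```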

#' bucket, at index numLabels). Default is "error".
#' @param ... additional arguments passed to the method.
#' @aliases spark.randomForest,SparkDataFrame,formula-method
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
maxMemoryInMB = 256, cacheNodeIds = FALSE) {
maxMemoryInMB = 256, cacheNodeIds = FALSE,
handleInvalid = c("error", "keep", "skip")) {
type <- match.arg(type)
formula <- paste(deparse(formula), collapse = "")
if (!is.null(seed)) {
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
new("RandomForestRegressionModel", jobj = jobj)
},
classification = {
handleInvalid <- match.arg(handleInvalid)
if (is.null(impurity)) impurity <- "gini"
impurity <- match.arg(impurity, c("gini", "entropy"))
jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
Expand All @@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
as.numeric(minInfoGain), as.integer(checkpointInterval),
as.character(featureSubsetStrategy), seed,
as.numeric(subsamplingRate),
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
handleInvalid)
new("RandomForestClassificationModel", jobj = jobj)
}
)
17 changes: 17 additions & 0 deletions R/pkg/tests/fulltests/test_mllib_tree.R
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
expect_equal(length(grep("1.0", predictions)), 50)
expect_equal(length(grep("2.0", predictions)), 50)

# Test unseen labels
data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
someString = base::sample(c("this", "that"), 10, replace = TRUE),
stringsAsFactors = FALSE)
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
traindf <- as.DataFrame(data[trainidxs, ])
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
maxDepth = 10, maxBins = 10, numTrees = 10)
predictions <- predict(model, testdf)
expect_error(collect(predictions))
Member (@felixcheung, Jul 4, 2017):
could you add the error string to match with expect_error

Member:
actually, this is a bit strange - so the spark.randomForest call and predict run successfully, and it only fails when the predictions are collected?

Contributor Author:
The training call has no error because the training data contains no unseen label.

I think the internal logic does handle the unseen label, but the transformation is lazy: it only runs when an action such as collect() is executed, and that is when it fails to map the unseen string to an index. That is why it only fails when the predictions are collected.

I will add the error string.
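
A sketch of what the matched assertion could look like, keyed to a stable substring of the exception shown below (an assumption about the eventual test, not the committed code):

```r
expect_error(collect(predictions),
             "Failed to execute user defined function")
```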

Contributor Author:
The console prints out:

```
Error in handleErrors(returnStatus, conn) :
  org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 13.0 failed 1 times, most recent failure: Lost task 0.0 in stage 13.0 (TID 13, localhost, executor driver): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$9: (string) => double)
```

Shall I match this?

Member:
hm, it looks like the task has failed

is this the proper or expected behavior on the ML side? it seems odd the error is not reported but instead the column is given a "wrong type"

Contributor Author:
Let me check how the "error" option is handled. It seems that no exception is thrown.

Contributor Author (@wangmiao1981, Jul 7, 2017):
On the Scala side, I created a case where an unseen label appears in the test data:
```scala
// Assumes a SparkSession named `spark` is in scope (in an ML test suite this
// would be `import testImplicits._` instead), so that toDF() is available.
import spark.implicits._
import org.apache.spark.ml.feature.StringIndexer

// Fitting data has labels "a", "b" (and a null); data2 adds the unseen label "d".
val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null))
val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, "d"))
val df = data.toDF("id", "label")
val df2 = data2.toDF("id", "label")

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")

indexer.setHandleInvalid("error")
// Fails at collect() because "d" was never seen when the indexer was fit.
indexer.fit(df).transform(df2).collect()
```
It also fails with the same error message as the R case. I think this is the expected behavior for "error".

Failure message:

```
Failed to execute user defined function($anonfun$9: (string) => double)
org.apache.spark.SparkException: Failed to execute user defined function($anonfun$9: (string) => double)
  at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1075)
  at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:139)
  at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:48)
  at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:30)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
```

model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
maxDepth = 10, maxBins = 10, numTrees = 10,
handleInvalid = "skip")
predictions <- predict(model, testdf)
expect_equal(class(collect(predictions)$clicked[1]), "character")

# spark.randomForest classification can work on libsvm data
if (windows_with_hadoop()) {
data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
25 changes: 25 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,30 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("1.5.0")
def getFormula: String = $(formula)

/**
* Param for how to handle invalid data (unseen labels or NULL values).
* Options are 'skip' (filter out rows with invalid data),
* 'error' (throw an error), or 'keep' (put invalid data in a special additional
* bucket, at index numLabels).
* Default: "error"
* @group param
*/
@Since("2.3.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
"invalid data (unseen labels or NULL values). " +
"Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
setDefault(handleInvalid, StringIndexer.ERROR_INVALID)

/** @group setParam */
@Since("2.3.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/** @group getParam */
@Since("2.3.0")
def getHandleInvalid: String = $(handleInvalid)

/** @group setParam */
@Since("1.5.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
.setInputCol(term)
.setOutputCol(indexCol)
.setStringOrderType($(stringIndexerOrderType))
.setHandleInvalid($(handleInvalid))
prefixesToRewrite(indexCol + "_") = term + "_"
(term, indexCol)
case _ =>
@@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
seed: String,
subsamplingRate: Double,
maxMemoryInMB: Int,
cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
cacheNodeIds: Boolean,
handleInvalid: String): RandomForestClassifierWrapper = {

val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
.setHandleInvalid(handleInvalid)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)

@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col