[SparkR][SPARK-20307] SparkR: pass on setHandleInvalid to spark.mllib functions that use StringIndexer #18496
Changes from all commits: a2cdf51, fa7bd4b, 116d996, 042cfbf, a2619c2, c71608d
```diff
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #' can speed up training of deeper trees. Users can set how often should the
 #' cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
+#'        Supported options: "skip" (filter out rows with invalid data),
+#'        "error" (throw an error), "keep" (put invalid data in a special additional
```
**Member:** is "error" the default behavior? since we are doing

**Contributor (Author):** Yes.
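The answer follows from the `match.arg` idiom used in the diff: when the caller omits the argument, `match.arg` returns the first element of the candidate vector, so `c("error", "keep", "skip")` defaults to `"error"`. A standalone base-R sketch (the function name is hypothetical, for illustration only):

```r
# Hypothetical helper mirroring the signature pattern in the diff.
choosePolicy <- function(handleInvalid = c("error", "keep", "skip")) {
  # When the argument is missing, match.arg returns the first candidate;
  # when supplied, it validates the value (with partial matching).
  match.arg(handleInvalid)
}

choosePolicy()        # returns "error" (the default)
choosePolicy("keep")  # returns "keep"
# choosePolicy("oops") stops with: 'arg' should be one of "error", "keep", "skip"
```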
```diff
+#'        bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
```
```diff
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                      maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                      featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                      minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                     maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                     maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                     handleInvalid = c("error", "keep", "skip")) {
            type <- match.arg(type)
            formula <- paste(deparse(formula), collapse = "")
            if (!is.null(seed)) {
```
```diff
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                        new("RandomForestRegressionModel", jobj = jobj)
                      },
                      classification = {
+                       handleInvalid <- match.arg(handleInvalid)
                        if (is.null(impurity)) impurity <- "gini"
                        impurity <- match.arg(impurity, c("gini", "entropy"))
                        jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
```
```diff
@@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                                             as.numeric(minInfoGain), as.integer(checkpointInterval),
                                             as.character(featureSubsetStrategy), seed,
                                             as.numeric(subsamplingRate),
-                                            as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                            as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                            handleInvalid)
                        new("RandomForestClassificationModel", jobj = jobj)
                      }
           )
```
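For context, a minimal SparkR usage sketch of the new parameter (a sketch only: it assumes a running Spark session created with `sparkR.session()`, and the data and column names are illustrative, not from the PR):

```r
library(SparkR)
sparkR.session()

# Training data covers the labels "this" and "that"; the test set adds an
# unseen value, mimicking the scenario the new parameter addresses.
train <- createDataFrame(data.frame(clicked = c(0, 1, 0, 1),
                                    someString = c("this", "that", "this", "that"),
                                    stringsAsFactors = FALSE))
test <- createDataFrame(data.frame(clicked = 0,
                                   someString = "the other",
                                   stringsAsFactors = FALSE))

# With the default handleInvalid = "error", collecting predictions for the
# unseen label fails; "skip" drops such rows, and "keep" maps them to an
# extra bucket at index numLabels.
model <- spark.randomForest(train, clicked ~ ., type = "classification",
                            handleInvalid = "skip")
head(predict(model, test))
```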
```diff
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 10, maxBins = 10, numTrees = 10)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
```
**Member:** could you add the error string to match with `expect_error`?

**Member:** actually, this is a bit strange: so the `spark.randomForest` call and `predict` run successfully, and it only fails when the predictions are collected?

**Contributor (Author):** The training call has no error because the training data contains no unseen label. I think the internals have logic for handling unseen labels, but when collection (an action) runs, the internal value cannot be mapped back to the unseen label. That is why it only fails at collection. I will add the error string.

**Contributor (Author):** The console prints: `Error in handleErrors(returnStatus, conn) :` Shall I match this?

**Member:** hm, it looks like the task has failed. Is this the proper or expected behavior on the ML side? It seems odd that the error is not reported and instead the column is given a "wrong type".

**Contributor (Author):** Let me check how the "error" option is handled. It seems that no exception is thrown.

**Contributor (Author):** On the Scala side, I created a case where an unseen label is used in the test data: `Failed Messages: Failed to execute user defined function($anonfun$9: (string) => double)`
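The reviewer's request can be met by passing a regular expression as the second argument of `expect_error`. A self-contained testthat sketch (the failing function and its message are stand-ins for `collect(predictions)` and the real Spark error, not the exact strings):

```r
library(testthat)

collectLikeSpark <- function() {
  # Stand-in for collect(predictions); the message mimics the UDF failure.
  stop("Failed to execute user defined function")
}

# Matching a stable substring keeps the test robust against the long
# JVM stack trace wrapped around the real Spark error message.
expect_error(collectLikeSpark(), "Failed to execute user defined function")
```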
```diff
+
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 10, maxBins = 10, numTrees = 10,
+                              handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
+
   # spark.randomForest classification can work on libsvm data
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
```
> is this on "features" or "labels"? It seems it's only set on `RFormula.terms`, which are features.

> I think `labels` here means the string label of a feature, which is categorical (e.g., `white`, `black`).