[SPARK-19133][SPARKR][ML] fix glm for Gamma, clarify glm family supported #16511
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,6 +52,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) | |
| #' This can be a character string naming a family function, a family function or | ||
| #' the result of a call to a family function. Refer R family at | ||
| #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. | ||
| #' Currently these families are supported: \code{binomial}, \code{gaussian}, | ||
| #' \code{Gamma}, and \code{poisson}. | ||
| #' @param tol positive convergence tolerance of iterations. | ||
| #' @param maxIter integer giving the maximal number of IRLS iterations. | ||
| #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance | ||
|
|
@@ -104,8 +106,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
| weightCol <- "" | ||
| } | ||
|
|
||
| # For known families, Gamma is upper-cased | ||
| jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", | ||
| "fit", formula, data@sdf, family$family, family$link, | ||
| "fit", formula, data@sdf, tolower(family$family), family$link, | ||
|
Contributor
I think we should put case conversion into MLlib side, then we can get consistent behavior from APIs across different languages (Scala/Python/R). I wish you would not mind that I have sent #16516 to fix it in MLlib. If it's reasonable, we can revert the change of this line.
Member
Author
that's a good idea - given how this is likely broken since Spark 1.6 I'd like to get the fix in ASAP and to master, branch-2.0, branch-2.1, or even perhaps branch-1.6 just in case. I'd probably need to open different PRs for these since our file organization has changed, and would need to keep the R side change - unless you think ML side change could go to these branches too?
Contributor
I think we need to take some time to get the MLlib side changed (and it involves changes for other estimators). Sounds good to keep the R side change. |
||
| tol, as.integer(maxIter), as.character(weightCol), regParam) | ||
| new("GeneralizedLinearRegressionModel", jobj = jobj) | ||
| }) | ||
|
|
@@ -120,6 +123,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
| #' This can be a character string naming a family function, a family function or | ||
| #' the result of a call to a family function. Refer R family at | ||
| #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. | ||
| #' Currently these families are supported: \code{binomial}, \code{gaussian}, | ||
| #' \code{Gamma}, and \code{poisson}. | ||
| #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance | ||
| #' weights as 1.0. | ||
| #' @param epsilon positive convergence tolerance of iterations. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,14 +61,22 @@ test_that("spark.glm and predict", { | |
|
|
||
| # poisson family | ||
| model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, | ||
| family = poisson(link = identity)) | ||
| family = poisson(link = identity)) | ||
| prediction <- predict(model, training) | ||
| expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") | ||
| vals <- collect(select(prediction, "prediction")) | ||
| rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, | ||
| data = iris, family = poisson(link = identity)), iris)) | ||
| data = iris, family = poisson(link = identity)), iris)) | ||
| expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) | ||
|
|
||
| # Gamma family | ||
| x <- runif(100, -1, 1) | ||
| y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10) | ||
| df <- as.DataFrame(as.data.frame(list(x = x, y = y))) | ||
| model <- glm(y ~ x, family = Gamma, df) | ||
|
Contributor
I'd prefer to use
Member
Author
I'm not sure it matters much - as you can see in the code |
||
| out <- capture.output(print(summary(model))) | ||
| expect_true(any(grepl("Dispersion parameter for gamma family", out))) | ||
|
|
||
| # Test stats::predict is working | ||
| x <- rnorm(15) | ||
| y <- x + rnorm(15) | ||
|
|
@@ -103,11 +111,11 @@ test_that("spark.glm summary", { | |
| df <- suppressWarnings(createDataFrame(iris)) | ||
| training <- df[df$Species %in% c("versicolor", "virginica"), ] | ||
| stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width, | ||
| family = binomial(link = "logit"))) | ||
| family = binomial(link = "logit"))) | ||
|
|
||
| rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] | ||
| rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, | ||
| family = binomial(link = "logit"))) | ||
| family = binomial(link = "logit"))) | ||
|
|
||
| coefs <- unlist(stats$coefficients) | ||
| rCoefs <- unlist(rStats$coefficients) | ||
|
|
@@ -222,7 +230,7 @@ test_that("glm and predict", { | |
| training <- suppressWarnings(createDataFrame(iris)) | ||
| # gaussian family | ||
| model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) | ||
| prediction <- predict(model, training) | ||
| prediction <- predict(model, training) | ||
| expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") | ||
| vals <- collect(select(prediction, "prediction")) | ||
| rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) | ||
|
|
@@ -235,7 +243,7 @@ test_that("glm and predict", { | |
| expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") | ||
| vals <- collect(select(prediction, "prediction")) | ||
| rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, | ||
| data = iris, family = poisson(link = identity)), iris)) | ||
| data = iris, family = poisson(link = identity)), iris)) | ||
| expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) | ||
|
|
||
| # Test stats::predict is working | ||
|
|
@@ -268,11 +276,11 @@ test_that("glm summary", { | |
| df <- suppressWarnings(createDataFrame(iris)) | ||
| training <- df[df$Species %in% c("versicolor", "virginica"), ] | ||
| stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, | ||
| family = binomial(link = "logit"))) | ||
| family = binomial(link = "logit"))) | ||
|
|
||
| rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] | ||
| rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, | ||
| family = binomial(link = "logit"))) | ||
| family = binomial(link = "logit"))) | ||
|
|
||
| coefs <- unlist(stats$coefficients) | ||
| rCoefs <- unlist(rStats$coefficients) | ||
|
|
@@ -409,7 +417,7 @@ test_that("spark.survreg", { | |
| x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) | ||
| expect_error( | ||
| model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), | ||
| NA) | ||
| NA) | ||
| expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) | ||
| } | ||
| }) | ||
|
|
||
in R, the Gamma family is spelled with a capital G
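
To illustrate the casing point raised in this review, here is a minimal sketch in base R (no SparkR session involved). It lists the `family` field of the four families the updated docs name and shows the effect of the `tolower()` call added on the R side. The variable names `families`, `family_names`, and `normalized` are made up for this illustration; the assumption that the JVM wrapper expected lower-case names at the time is inferred from the discussion above, not verified here.

```r
# Base R only; the family constructors come from the stats package,
# which is attached by default.
families <- list(binomial(), gaussian(), Gamma(), poisson())

# Each family object stores its name in the `family` field.
# Gamma is the only one of the four whose name is capitalized.
family_names <- vapply(families, function(f) f$family, character(1))
print(family_names)

# The R-side change in this PR lower-cases the name before it is
# passed to the JVM wrapper, so "Gamma" becomes "gamma".
normalized <- tolower(family_names)
print(normalized)
```

Lower-casing on the R side leaves the other three names unchanged, which is why the change is limited to a single argument of the `callJStatic` call.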