-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19391][SparkR][ML] Tweedie GLM API for SparkR #16729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
67364ab
654551b
852dd6e
5aa4ae7
3682692
3555afb
56f6da0
083849c
fb66ce0
0d722fd
d11fc4b
4c24158
295711d
c315fb1
9be9c51
201939b
6737122
0b5ed43
b10777e
7d5bd60
a9ac439
f540922
6cbc62f
ef65adc
c11e57c
5ce4c84
aeeb3f7
4cffc40
0b496a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,8 @@ Suggests: | |
| rmarkdown, | ||
| testthat, | ||
| e1071, | ||
| survival | ||
| survival, | ||
| statmod | ||
| Collate: | ||
| 'schema.R' | ||
| 'generics.R' | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -84,6 +84,12 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) | |
| #' # can also read back the saved model and print | ||
| #' savedModel <- read.ml(path) | ||
| #' summary(savedModel) | ||
| #' | ||
| #' # fit tweedie model | ||
| #' require(statmod) | ||
|
||
| #' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, | ||
| #' family = tweedie(var.power = 1.2, link.power = 0)) | ||
| #' summary(model) | ||
| #' } | ||
| #' @note spark.glm since 2.0.0 | ||
| #' @seealso \link{glm}, \link{read.ml} | ||
|
|
@@ -101,6 +107,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
| stop("'family' not recognized") | ||
| } | ||
|
|
||
| # recover variancePower and linkPower from the specified tweedie family | ||
| if (tolower(family$family) == "tweedie") { | ||
| variancePower <- log(family$variance(exp(1))) | ||
| linkPower <- log(family$linkfun(exp(1))) | ||
| } else { | ||
| # these default values are not used | ||
| variancePower <- 0.0 | ||
| linkPower <- 1.0 | ||
| } | ||
|
|
||
| formula <- paste(deparse(formula), collapse = "") | ||
| if (is.null(weightCol)) { | ||
| weightCol <- "" | ||
|
|
@@ -109,7 +125,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
| # For known families, Gamma is upper-cased | ||
| jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", | ||
| "fit", formula, data@sdf, tolower(family$family), family$link, | ||
| tol, as.integer(maxIter), as.character(weightCol), regParam) | ||
| tol, as.integer(maxIter), as.character(weightCol), regParam, | ||
| as.double(variancePower), as.double(linkPower)) | ||
|
||
| new("GeneralizedLinearRegressionModel", jobj = jobj) | ||
| }) | ||
|
|
||
|
|
@@ -124,7 +141,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
| #' the result of a call to a family function. Refer R family at | ||
| #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. | ||
| #' Currently these families are supported: \code{binomial}, \code{gaussian}, | ||
| #' \code{Gamma}, and \code{poisson}. | ||
| #' \code{poisson}, \code{Gamma}, and \code{tweedie} (\code{statmod} package). | ||
| #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance | ||
| #' weights as 1.0. | ||
| #' @param epsilon positive convergence tolerance of iterations. | ||
|
|
@@ -170,9 +187,10 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), | |
| deviance <- callJMethod(jobj, "rDeviance") | ||
| df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") | ||
| df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") | ||
| aic <- callJMethod(jobj, "rAic") | ||
| iter <- callJMethod(jobj, "rNumIterations") | ||
| family <- callJMethod(jobj, "rFamily") | ||
| aic <- callJMethod(jobj, "rAic") | ||
| if (family == "tweedie" && aic == 0) aic <- NA | ||
| deviance.resid <- if (is.loaded) { | ||
| NULL | ||
| } else { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -77,6 +77,18 @@ test_that("spark.glm and predict", { | |
| out <- capture.output(print(summary(model))) | ||
| expect_true(any(grepl("Dispersion parameter for gamma family", out))) | ||
|
|
||
| # tweedie family | ||
| require(statmod) | ||
|
||
| model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, | ||
| family = tweedie(var.power = 1.2, link.power = 1.0)) | ||
| prediction <- predict(model, training) | ||
| expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. you might want to use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would you remind me what |
||
| vals <- collect(select(prediction, "prediction")) | ||
| rVals <- suppressWarnings(predict( | ||
|
||
| glm(Sepal.Width ~ Sepal.Length + Species, data = iris, | ||
| family = tweedie(var.power = 1.2, link.power = 1.0)), iris)) | ||
| expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) | ||
|
|
||
| # Test stats::predict is working | ||
| x <- rnorm(15) | ||
| y <- x + rnorm(15) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,9 @@ private[r] object GeneralizedLinearRegressionWrapper | |
| tol: Double, | ||
| maxIter: Int, | ||
| weightCol: String, | ||
| regParam: Double): GeneralizedLinearRegressionWrapper = { | ||
| regParam: Double, | ||
| variancePower: Double, | ||
| linkPower: Double): GeneralizedLinearRegressionWrapper = { | ||
| val rFormula = new RFormula().setFormula(formula) | ||
| checkDataColumns(rFormula, data) | ||
| val rFormulaModel = rFormula.fit(data) | ||
|
|
@@ -81,15 +83,20 @@ private[r] object GeneralizedLinearRegressionWrapper | |
| .attributes.get | ||
| val features = featureAttrs.map(_.name.get) | ||
| // assemble and fit the pipeline | ||
| val glr = new GeneralizedLinearRegression() | ||
| var glr = new GeneralizedLinearRegression() | ||
| .setFamily(family) | ||
| .setLink(link) | ||
| .setFitIntercept(rFormula.hasIntercept) | ||
| .setTol(tol) | ||
| .setMaxIter(maxIter) | ||
| .setWeightCol(weightCol) | ||
| .setRegParam(regParam) | ||
| .setFeaturesCol(rFormula.getFeaturesCol) | ||
| // set variancePower and linkPower if family is tweedie; otherwise, set link function | ||
| if (family.toLowerCase == "tweedie") { | ||
| glr = glr.setVariancePower(variancePower).setLinkPower(linkPower) | ||
|
||
| } else { | ||
| glr = glr.setLink(link) | ||
| } | ||
| val pipeline = new Pipeline() | ||
| .setStages(Array(rFormulaModel, glr)) | ||
| .fit(data) | ||
|
|
@@ -143,7 +150,12 @@ private[r] object GeneralizedLinearRegressionWrapper | |
| val rDeviance: Double = summary.deviance | ||
| val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull | ||
| val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom | ||
| val rAic: Double = summary.aic | ||
| val rAic: Double = if (family.toLowerCase == "tweedie" && | ||
| !Array(0.0, 1.0, 2.0).contains(variancePower)) { | ||
|
||
| 0.0 | ||
| } else { | ||
| summary.aic | ||
| } | ||
| val rNumIterations: Int = summary.numIterations | ||
|
|
||
| new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update L56 for documentation. Also, we should update the programming guide and vignettes too.