7 changes: 6 additions & 1 deletion R/pkg/R/mllib_regression.R
@@ -52,6 +52,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' Currently these families are supported: \code{binomial}, \code{gaussian},
#' \code{Gamma}, and \code{poisson}.
@felixcheung (Member, Author), Jan 9, 2017:

In R, the Gamma family is spelled with a capital G.
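For reference, a minimal base-R sketch of the naming the comment refers to; the tolower() call added in the hunk below normalizes the name before it reaches the JVM-side wrapper:

binomial()$family          # "binomial"
gaussian()$family          # "gaussian"
poisson()$family           # "poisson"
Gamma()$family             # "Gamma"  (capital G)
tolower(Gamma()$family)    # "gamma"  (the normalized name handed to the Scala-side "fit" call)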

#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
@@ -104,8 +106,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
weightCol <- ""
}

# For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
"fit", formula, data@sdf, tolower(family$family), family$link,
@yanboliang (Contributor), Jan 9, 2017:

I think we should put the case conversion on the MLlib side, so we get consistent behavior across the APIs in different languages (Scala/Python/R). I hope you don't mind that I have sent #16516 to fix it in MLlib. If that's reasonable, we can revert the change on this line.

Member Author:

That's a good idea. Given that this has likely been broken since Spark 1.6, I'd like to get the fix in ASAP, to master, branch-2.0, branch-2.1, and perhaps even branch-1.6 just in case.

I'd probably need to open separate PRs for those branches since our file organization has changed, and I would need to keep the R-side change, unless you think the ML-side change could go to those branches too?

Contributor:

I think we need to take some time to get the MLlib side changed (and it involves changes to other estimators). Sounds good to keep the R-side change.

tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
@@ -120,6 +123,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' Currently these families are supported: \code{binomial}, \code{gaussian},
#' \code{Gamma}, and \code{poisson}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param epsilon positive convergence tolerance of iterations.
2 changes: 1 addition & 1 deletion R/pkg/R/sparkR.R
@@ -423,7 +423,7 @@ sparkR.session <- function(
#' sparkR.session()
#' url <- sparkR.uiWebUrl()
#' }
#' @note sparkR.uiWebUrl since 2.2.0
#' @note sparkR.uiWebUrl since 2.1.1
sparkR.uiWebUrl <- function() {
sc <- sparkR.callJMethod(getSparkContext(), "sc")
u <- callJMethod(sc, "uiWebUrl")
26 changes: 17 additions & 9 deletions R/pkg/inst/tests/testthat/test_mllib_regression.R
@@ -61,14 +61,22 @@ test_that("spark.glm and predict", {

# poisson family
model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
family = poisson(link = identity))
family = poisson(link = identity))
prediction <- predict(model, training)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
data = iris, family = poisson(link = identity)), iris))
data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)

# Gamma family
x <- runif(100, -1, 1)
y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
model <- glm(y ~ x, family = Gamma, df)
Contributor:

I'd prefer to use spark.glm here, since the title of this test is "spark.glm and predict" (see L52). We have separate tests for the R-compatible glm method, and it's not necessary to duplicate all tests.

@felixcheung (Member, Author), Jan 9, 2017:

I'm not sure it matters much; as you can see in the code, glm is a single-line wrapper around spark.glm. I actually thought it was better to add some tests for glm instead of only testing spark.glm.
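For reference, a rough sketch of the spark.glm form suggested above, reusing the x, y, and df built earlier in this test; the names model2 and out2 are illustrative and not part of the change:

model2 <- spark.glm(df, y ~ x, family = Gamma)
out2 <- capture.output(print(summary(model2)))
expect_true(any(grepl("Dispersion parameter for gamma family", out2)))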

out <- capture.output(print(summary(model)))
expect_true(any(grepl("Dispersion parameter for gamma family", out)))

# Test stats::predict is working
x <- rnorm(15)
y <- x + rnorm(15)
@@ -103,11 +111,11 @@ test_that("spark.glm summary", {
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width,
family = binomial(link = "logit")))
family = binomial(link = "logit")))

rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit")))
family = binomial(link = "logit")))

coefs <- unlist(stats$coefficients)
rCoefs <- unlist(rStats$coefficients)
@@ -222,7 +230,7 @@ test_that("glm and predict", {
training <- suppressWarnings(createDataFrame(iris))
# gaussian family
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
prediction <- predict(model, training)
prediction <- predict(model, training)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
@@ -235,7 +243,7 @@ test_that("glm and predict", {
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
data = iris, family = poisson(link = identity)), iris))
data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)

# Test stats::predict is working
@@ -268,11 +276,11 @@ test_that("glm summary", {
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
family = binomial(link = "logit")))
family = binomial(link = "logit")))

rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit")))
family = binomial(link = "logit")))

coefs <- unlist(stats$coefficients)
rCoefs <- unlist(rStats$coefficients)
@@ -409,7 +417,7 @@ test_that("spark.survreg", {
x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
expect_error(
model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
NA)
NA)
expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
}
})