-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-15339] [ML] ML 2.0 QA: Scala APIs and code audit for regression #13129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
254313c
645f6c4
374e610
d38b1eb
1fbd1dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam | |
| with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol | ||
| with HasSolver with Logging { | ||
|
|
||
| import GeneralizedLinearRegression._ | ||
|
|
||
| /** | ||
| * Param for the name of family which is a description of the error distribution | ||
| * to be used in the model. | ||
|
|
@@ -54,8 +56,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam | |
| @Since("2.0.0") | ||
| final val family: Param[String] = new Param(this, "family", | ||
| "The name of family which is a description of the error distribution to be used in the " + | ||
| "model. Supported options: gaussian(default), binomial, poisson and gamma.", | ||
| ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray)) | ||
| s"model. Supported options: ${supportedFamilyNames.mkString(", ")}. (Default is 'gaussian')", | ||
| ParamValidators.inArray[String](supportedFamilyNames.toArray)) | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.0.0") | ||
|
|
@@ -71,24 +73,22 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam | |
| @Since("2.0.0") | ||
| final val link: Param[String] = new Param(this, "link", "The name of link function " + | ||
| "which provides the relationship between the linear predictor and the mean of the " + | ||
| "distribution function. Supported options: identity, log, inverse, logit, probit, " + | ||
| "cloglog and sqrt.", | ||
| ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray)) | ||
| s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", | ||
| ParamValidators.inArray[String](supportedLinkNames.toArray)) | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.0.0") | ||
| def getLink: String = $(link) | ||
|
|
||
| /** | ||
| * Param for link prediction (linear predictor) column name. | ||
| * Default is empty, which means we do not output link prediction. | ||
| * Default is not set, which means we do not output link prediction. | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("2.0.0") | ||
| final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol", | ||
| "link prediction (linear predictor) column name") | ||
| setDefault(linkPredictionCol, "") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.0.0") | ||
|
|
@@ -107,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam | |
| s"with ${$(family)} family does not support ${$(link)} link function.") | ||
| } | ||
| val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) | ||
| if ($(linkPredictionCol).nonEmpty) { | ||
| if (isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty) { | ||
|
||
| SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) | ||
| } else { | ||
| newSchema | ||
|
|
@@ -205,7 +205,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |
| /** | ||
| * Sets the value of param [[weightCol]]. | ||
| * If this is not set or empty, we treat all instance weights as 1.0. | ||
| * Default is empty, so all instances have weight one. | ||
| * Default is not set, so all instances have weight one. | ||
| * | ||
| * @group setParam | ||
| */ | ||
|
|
@@ -214,7 +214,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |
|
|
||
| /** | ||
| * Sets the solver algorithm used for optimization. | ||
| * Currently only support "irls" which is also the default solver. | ||
| * Currently only supports "irls" which is also the default solver. | ||
| * | ||
| * @group setParam | ||
| */ | ||
|
|
@@ -239,10 +239,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |
| } | ||
| val familyAndLink = new FamilyAndLink(familyObj, linkObj) | ||
|
|
||
| val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd | ||
| .map { case Row(features: Vector) => | ||
| features.size | ||
| }.first() | ||
| val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we not do
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like Spark does not provide an encoder for Vector. If I change to use
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah right, we would need to add an implicit encoder. However, let's leave that change for #12718 |
||
| if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { | ||
| val msg = "Currently, GeneralizedLinearRegression only supports number of features" + | ||
| s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset." | ||
|
|
@@ -294,25 +291,25 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| override def load(path: String): GeneralizedLinearRegression = super.load(path) | ||
|
|
||
| /** Set of family and link pairs that GeneralizedLinearRegression supports. */ | ||
| private[ml] lazy val supportedFamilyAndLinkPairs = Set( | ||
| private[regression] lazy val supportedFamilyAndLinkPairs = Set( | ||
| Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, | ||
| Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, | ||
| Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, | ||
| Gamma -> Inverse, Gamma -> Identity, Gamma -> Log | ||
| ) | ||
|
|
||
| /** Set of family names that GeneralizedLinearRegression supports. */ | ||
| private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) | ||
| private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) | ||
|
|
||
| /** Set of link names that GeneralizedLinearRegression supports. */ | ||
| private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) | ||
| private[regression] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) | ||
|
|
||
| private[ml] val epsilon: Double = 1E-16 | ||
| private[regression] val epsilon: Double = 1E-16 | ||
|
|
||
| /** | ||
| * Wrapper of family and link combination used in the model. | ||
| */ | ||
| private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { | ||
| private[regression] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { | ||
|
|
||
| /** Linear predictor based on given mu. */ | ||
| def predict(mu: Double): Double = link.link(family.project(mu)) | ||
|
|
@@ -359,7 +356,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| * | ||
| * @param name the name of the family. | ||
| */ | ||
| private[ml] abstract class Family(val name: String) extends Serializable { | ||
| private[regression] abstract class Family(val name: String) extends Serializable { | ||
|
|
||
| /** The default link instance of this family. */ | ||
| val defaultLink: Link | ||
|
|
@@ -391,7 +388,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| def project(mu: Double): Double = mu | ||
| } | ||
|
|
||
| private[ml] object Family { | ||
| private[regression] object Family { | ||
|
|
||
| /** | ||
| * Gets the [[Family]] object from its name. | ||
|
|
@@ -412,7 +409,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| * Gaussian exponential family distribution. | ||
| * The default link for the Gaussian family is the identity link. | ||
| */ | ||
| private[ml] object Gaussian extends Family("gaussian") { | ||
| private[regression] object Gaussian extends Family("gaussian") { | ||
|
|
||
| val defaultLink: Link = Identity | ||
|
|
||
|
|
@@ -448,7 +445,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| * Binomial exponential family distribution. | ||
| * The default link for the Binomial family is the logit link. | ||
| */ | ||
| private[ml] object Binomial extends Family("binomial") { | ||
| private[regression] object Binomial extends Family("binomial") { | ||
|
|
||
| val defaultLink: Link = Logit | ||
|
|
||
|
|
@@ -492,7 +489,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| * Poisson exponential family distribution. | ||
| * The default link for the Poisson family is the log link. | ||
| */ | ||
| private[ml] object Poisson extends Family("poisson") { | ||
| private[regression] object Poisson extends Family("poisson") { | ||
|
|
||
| val defaultLink: Link = Log | ||
|
|
||
|
|
@@ -533,7 +530,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| * Gamma exponential family distribution. | ||
| * The default link for the Gamma family is the inverse link. | ||
| */ | ||
| private[ml] object Gamma extends Family("gamma") { | ||
| private[regression] object Gamma extends Family("gamma") { | ||
|
|
||
| val defaultLink: Link = Inverse | ||
|
|
||
|
|
@@ -578,7 +575,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| * | ||
| * @param name the name of link function. | ||
| */ | ||
| private[ml] abstract class Link(val name: String) extends Serializable { | ||
| private[regression] abstract class Link(val name: String) extends Serializable { | ||
|
|
||
| /** The link function. */ | ||
| def link(mu: Double): Double | ||
|
|
@@ -590,7 +587,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| def unlink(eta: Double): Double | ||
| } | ||
|
|
||
| private[ml] object Link { | ||
| private[regression] object Link { | ||
|
|
||
| /** | ||
| * Gets the [[Link]] object from its name. | ||
|
|
@@ -611,7 +608,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| } | ||
| } | ||
|
|
||
| private[ml] object Identity extends Link("identity") { | ||
| private[regression] object Identity extends Link("identity") { | ||
|
|
||
| override def link(mu: Double): Double = mu | ||
|
|
||
|
|
@@ -620,7 +617,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| override def unlink(eta: Double): Double = eta | ||
| } | ||
|
|
||
| private[ml] object Logit extends Link("logit") { | ||
| private[regression] object Logit extends Link("logit") { | ||
|
|
||
| override def link(mu: Double): Double = math.log(mu / (1.0 - mu)) | ||
|
|
||
|
|
@@ -629,7 +626,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) | ||
| } | ||
|
|
||
| private[ml] object Log extends Link("log") { | ||
| private[regression] object Log extends Link("log") { | ||
|
|
||
| override def link(mu: Double): Double = math.log(mu) | ||
|
|
||
|
|
@@ -638,7 +635,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| override def unlink(eta: Double): Double = math.exp(eta) | ||
| } | ||
|
|
||
| private[ml] object Inverse extends Link("inverse") { | ||
| private[regression] object Inverse extends Link("inverse") { | ||
|
|
||
| override def link(mu: Double): Double = 1.0 / mu | ||
|
|
||
|
|
@@ -647,7 +644,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| override def unlink(eta: Double): Double = 1.0 / eta | ||
| } | ||
|
|
||
| private[ml] object Probit extends Link("probit") { | ||
| private[regression] object Probit extends Link("probit") { | ||
|
|
||
| override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) | ||
|
|
||
|
|
@@ -658,7 +655,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) | ||
| } | ||
|
|
||
| private[ml] object CLogLog extends Link("cloglog") { | ||
| private[regression] object CLogLog extends Link("cloglog") { | ||
|
|
||
| override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu)) | ||
|
|
||
|
|
@@ -667,7 +664,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) | ||
| } | ||
|
|
||
| private[ml] object Sqrt extends Link("sqrt") { | ||
| private[regression] object Sqrt extends Link("sqrt") { | ||
|
|
||
| override def link(mu: Double): Double = math.sqrt(mu) | ||
|
|
||
|
|
@@ -732,7 +729,7 @@ class GeneralizedLinearRegressionModel private[ml] ( | |
| if ($(predictionCol).nonEmpty) { | ||
| output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) | ||
| } | ||
| if ($(linkPredictionCol).nonEmpty) { | ||
| if (isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty) { | ||
| output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) | ||
| } | ||
| output.toDF() | ||
|
|
@@ -853,7 +850,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( | |
| */ | ||
| @Since("2.0.0") | ||
| val predictionCol: String = { | ||
| if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") { | ||
| if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol.nonEmpty) { | ||
| origModel.getPredictionCol | ||
| } else { | ||
| "prediction_" + java.util.UUID.randomUUID.toString | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,8 +69,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures | |
| setDefault(isotonic -> true, featureIndex -> 0) | ||
|
|
||
| /** Checks whether the input has weight column. */ | ||
| protected[ml] def hasWeightCol: Boolean = { | ||
| isDefined(weightCol) && $(weightCol) != "" | ||
| protected def hasWeightCol: Boolean = { | ||
|
||
| isDefined(weightCol) && $(weightCol).nonEmpty | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -159,9 +159,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
|
|
||
| override protected def train(dataset: Dataset[_]): LinearRegressionModel = { | ||
| // Extract the number of features before deciding optimization solver. | ||
| val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { | ||
| case Row(features: Vector) => features.size | ||
| }.first() | ||
| val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here, can we do |
||
| val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) | ||
|
|
||
| if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && | ||
|
|
@@ -240,7 +238,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
| val coefficients = Vectors.sparse(numFeatures, Seq()) | ||
| val intercept = yMean | ||
|
|
||
| val model = new LinearRegressionModel(uid, coefficients, intercept) | ||
| val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) | ||
| // Handle possible missing or invalid prediction columns | ||
| val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() | ||
|
|
||
|
|
@@ -252,7 +250,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
| model, | ||
| Array(0D), | ||
| Array(0D)) | ||
| return copyValues(model.setSummary(trainingSummary)) | ||
| return model.setSummary(trainingSummary) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a minor bug of
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So test cases didn't pick this up? We should look into why and amend the tests accordingly.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is because we don't have excellent test coverage ...
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @MLnick I added a test case for this scenario and updated the other test cases to ensure that the prediction column (and other params) are copied correctly in all situations.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We also need to setParent
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jkbradley It is not necessary to |
||
| } else { | ||
| require($(regParam) == 0.0, "The standard deviation of the label is zero. " + | ||
| "Model cannot be regularized.") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is probably meant to be
private[regression]