-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17748][FOLLOW-UP][ML] Reorg variables of WeightedLeastSquares. #15621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -101,23 +101,19 @@ private[ml] class WeightedLeastSquares( | |
| summary.validate() | ||
| logInfo(s"Number of instances: ${summary.count}.") | ||
| val k = if (fitIntercept) summary.k + 1 else summary.k | ||
| val numFeatures = summary.k | ||
| val triK = summary.triK | ||
| val wSum = summary.wSum | ||
| val bBar = summary.bBar | ||
| val bbBar = summary.bbBar | ||
| val aBar = summary.aBar | ||
| val aStd = summary.aStd | ||
| val abBar = summary.abBar | ||
| val aaBar = summary.aaBar | ||
| val numFeatures = abBar.size | ||
|
|
||
| val rawBStd = summary.bStd | ||
| val rawBBar = summary.bBar | ||
| // if b is constant (rawBStd is zero), then b cannot be scaled. In this case | ||
| // setting bStd=abs(bBar) ensures that b is not scaled anymore in l-bfgs algorithm. | ||
| val bStd = if (rawBStd == 0.0) math.abs(bBar) else rawBStd | ||
| // setting bStd=abs(rawBBar) ensures that b is not scaled anymore in l-bfgs algorithm. | ||
| val bStd = if (rawBStd == 0.0) math.abs(rawBBar) else rawBStd | ||
|
|
||
| if (rawBStd == 0) { | ||
| if (fitIntercept || bBar == 0.0) { | ||
| if (bBar == 0.0) { | ||
| if (fitIntercept || rawBBar == 0.0) { | ||
| if (rawBBar == 0.0) { | ||
| logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + | ||
| s"and the intercept will all be zero; as a result, training is not needed.") | ||
| } else { | ||
|
|
@@ -126,7 +122,7 @@ private[ml] class WeightedLeastSquares( | |
| s"training is not needed.") | ||
| } | ||
| val coefficients = new DenseVector(Array.ofDim(numFeatures)) | ||
| val intercept = bBar | ||
| val intercept = rawBBar | ||
| val diagInvAtWA = new DenseVector(Array(0D)) | ||
| return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D)) | ||
| } else { | ||
|
|
@@ -137,61 +133,68 @@ private[ml] class WeightedLeastSquares( | |
| } | ||
| } | ||
|
|
||
| // scale aBar to standardized space in-place | ||
| val aBarValues = aBar.values | ||
| var j = 0 | ||
| while (j < numFeatures) { | ||
| if (aStd(j) == 0.0) { | ||
| aBarValues(j) = 0.0 | ||
| } else { | ||
| aBarValues(j) /= aStd(j) | ||
| } | ||
| j += 1 | ||
| } | ||
| val bBar = summary.bBar / bStd | ||
| val bbBar = summary.bbBar / (bStd * bStd) | ||
|
|
||
| // scale abBar to standardized space in-place | ||
| val abBarValues = abBar.values | ||
| val aStdValues = aStd.values | ||
| j = 0 | ||
| while (j < numFeatures) { | ||
| if (aStdValues(j) == 0.0) { | ||
| abBarValues(j) = 0.0 | ||
| } else { | ||
| abBarValues(j) /= (aStdValues(j) * bStd) | ||
| val aStd = summary.aStd | ||
| val aBar = { | ||
| val _aBar = summary.aBar | ||
| var i = 0 | ||
| // scale aBar to standardized space in-place | ||
| while (i < numFeatures) { | ||
| if (aStd(i) == 0.0) { | ||
| _aBar.values(i) = 0.0 | ||
|
||
| } else { | ||
| _aBar.values(i) /= aStd(i) | ||
| } | ||
| i += 1 | ||
| } | ||
| j += 1 | ||
| _aBar | ||
| } | ||
|
|
||
| // scale aaBar to standardized space in-place | ||
| val aaBarValues = aaBar.values | ||
| j = 0 | ||
| var p = 0 | ||
| while (j < numFeatures) { | ||
| val aStdJ = aStdValues(j) | ||
| val abBar = { | ||
| val _abBar = summary.abBar | ||
| var i = 0 | ||
| while (i <= j) { | ||
| val aStdI = aStdValues(i) | ||
| if (aStdJ == 0.0 || aStdI == 0.0) { | ||
| aaBarValues(p) = 0.0 | ||
| // scale abBar to standardized space in-place | ||
| while (i < numFeatures) { | ||
| if (aStd(i) == 0.0) { | ||
|
||
| _abBar.values(i) = 0.0 | ||
| } else { | ||
| aaBarValues(p) /= (aStdI * aStdJ) | ||
| _abBar.values(i) /= (aStd(i) * bStd) | ||
| } | ||
| p += 1 | ||
| i += 1 | ||
| } | ||
| j += 1 | ||
| _abBar | ||
| } | ||
| val aaBar = { | ||
| val _aaBar = summary.aaBar | ||
| var j = 0 | ||
| var p = 0 | ||
| // scale aaBar to standardized space in-place | ||
| while (j < numFeatures) { | ||
| val aStdJ = aStd.values(j) | ||
| var i = 0 | ||
| while (i <= j) { | ||
| val aStdI = aStd.values(i) | ||
| if (aStdJ == 0.0 || aStdI == 0.0) { | ||
| _aaBar.values(p) = 0.0 | ||
| } else { | ||
| _aaBar.values(p) /= (aStdI * aStdJ) | ||
| } | ||
| p += 1 | ||
| i += 1 | ||
| } | ||
| j += 1 | ||
| } | ||
| _aaBar | ||
| } | ||
|
|
||
| val bBarStd = bBar / bStd | ||
| val bbBarStd = bbBar / (bStd * bStd) | ||
|
|
||
| val effectiveRegParam = regParam / bStd | ||
| val effectiveL1RegParam = elasticNetParam * effectiveRegParam | ||
| val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam | ||
|
|
||
| // add L2 regularization to diagonals | ||
| var i = 0 | ||
| j = 2 | ||
| var j = 2 | ||
| while (i < triK) { | ||
| var lambda = effectiveL2RegParam | ||
| if (!standardizeFeatures) { | ||
|
|
@@ -205,12 +208,13 @@ private[ml] class WeightedLeastSquares( | |
| if (!standardizeLabel) { | ||
| lambda *= bStd | ||
| } | ||
| aaBarValues(i) += lambda | ||
| aaBar.values(i) += lambda | ||
| i += j | ||
| j += 1 | ||
| } | ||
|
|
||
| val aa = getAtA(aaBar.values, aBar.values) | ||
| val ab = getAtB(abBar.values, bBarStd) | ||
| val ab = getAtB(abBar.values, bBar) | ||
|
|
||
| val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 && | ||
| regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) { | ||
|
|
@@ -222,7 +226,7 @@ private[ml] class WeightedLeastSquares( | |
| if (standardizeFeatures) { | ||
| effectiveL1RegParam | ||
| } else { | ||
| if (aStdValues(index) != 0.0) effectiveL1RegParam / aStdValues(index) else 0.0 | ||
| if (aStd.values(index) != 0.0) effectiveL1RegParam / aStd.values(index) else 0.0 | ||
| } | ||
| } | ||
| }) | ||
|
|
@@ -237,22 +241,23 @@ private[ml] class WeightedLeastSquares( | |
| val solution = solver match { | ||
| case cholesky: CholeskySolver => | ||
| try { | ||
| cholesky.solve(bBarStd, bbBarStd, ab, aa, aBar) | ||
| cholesky.solve(bBar, bbBar, ab, aa, aBar) | ||
| } catch { | ||
| // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to | ||
| // quasi-newton solver | ||
| // Quasi-Newton solver. | ||
| case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => | ||
| logWarning("Cholesky solver failed due to singular covariance matrix. " + | ||
| "Retrying with Quasi-Newton solver.") | ||
| // ab and aa were modified in place, so reconstruct them | ||
| val _aa = getAtA(aaBar.values, aBar.values) | ||
| val _ab = getAtB(abBar.values, bBarStd) | ||
| val _ab = getAtB(abBar.values, bBar) | ||
| val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None) | ||
| newSolver.solve(bBarStd, bbBarStd, _ab, _aa, aBar) | ||
| newSolver.solve(bBar, bbBar, _ab, _aa, aBar) | ||
| } | ||
| case qn: QuasiNewtonSolver => | ||
| qn.solve(bBarStd, bbBarStd, ab, aa, aBar) | ||
| qn.solve(bBar, bbBar, ab, aa, aBar) | ||
| } | ||
|
|
||
| val (coefficientArray, intercept) = if (fitIntercept) { | ||
| (solution.coefficients.slice(0, solution.coefficients.length - 1), | ||
| solution.coefficients.last * bStd) | ||
|
|
@@ -264,14 +269,18 @@ private[ml] class WeightedLeastSquares( | |
| var q = 0 | ||
| val len = coefficientArray.length | ||
| while (q < len) { | ||
| coefficientArray(q) *= { if (aStdValues(q) != 0.0) bStd / aStdValues(q) else 0.0 } | ||
| coefficientArray(q) *= { if (aStd.values(q) != 0.0) bStd / aStd.values(q) else 0.0 } | ||
| q += 1 | ||
| } | ||
|
|
||
| // aaInv is a packed upper triangular matrix, here we get all elements on diagonal | ||
| val diagInvAtWA = solution.aaInv.map { inv => | ||
| new DenseVector((1 to k).map { i => | ||
| val multiplier = if (i == k && fitIntercept) 1.0 else aStdValues(i - 1) * aStdValues(i - 1) | ||
| val multiplier = if (i == k && fitIntercept) { | ||
| 1.0 | ||
| } else { | ||
| aStd.values(i - 1) * aStd.values(i - 1) | ||
| } | ||
| inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier) | ||
| }.toArray) | ||
| }.getOrElse(new DenseVector(Array(0D))) | ||
|
|
@@ -280,7 +289,7 @@ private[ml] class WeightedLeastSquares( | |
| solution.objectiveHistory.getOrElse(Array(0D))) | ||
| } | ||
|
|
||
| /** Construct A^T^ A from summary statistics. */ | ||
| /** Construct A^T^ A (append bias if necessary). */ | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Summary statistics are in original space, here we construct A^T^ A from standardized space. |
||
| private def getAtA(aaBar: Array[Double], aBar: Array[Double]): DenseVector = { | ||
| if (fitIntercept) { | ||
| new DenseVector(Array.concat(aaBar, aBar, Array(1.0))) | ||
|
|
@@ -289,7 +298,7 @@ private[ml] class WeightedLeastSquares( | |
| } | ||
| } | ||
|
|
||
| /** Construct A^T^ b from summary statistics. */ | ||
| /** Construct A^T^ b (append bias if necessary). */ | ||
| private def getAtB(abBar: Array[Double], bBar: Double): DenseVector = { | ||
| if (fitIntercept) { | ||
| new DenseVector(Array.concat(abBar, Array(bBar))) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer the names
bBarStdandbbBarStdhere, as I think they're more descriptive. But it is not a strong preference so I'll leave it up to you.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we use
bBar, bbBar, aBar, abBar, aaBarin standardized space always, so I did not appendStdas suffix for all variables. If we only add suffix forbBarandbbBar, developers may misinterpret that other variables are not in standardized space.