Skip to content

Commit 4554529

Browse files
MechCodermengxr
authored andcommitted
[SPARK-4406] [MLib] FIX: Validate k in SVD
Raise exception when k is non-positive in SVD Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #3945 from MechCoder/spark-4406 and squashes the following commits: 64e6d2d [MechCoder] TST: Add better test errors and messages 12dae73 [MechCoder] [SPARK-4406] FIX: Validate k in SVD
1 parent 8782eb9 commit 4554529

4 files changed

Lines changed: 19 additions & 1 deletion

File tree

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ class IndexedRowMatrix(
102102
k: Int,
103103
computeU: Boolean = false,
104104
rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = {
105+
106+
val n = numCols().toInt
107+
require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
105108
val indices = rows.map(_.index)
106109
val svd = toRowMatrix().computeSVD(k, computeU, rCond)
107110
val U = if (computeU) {

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ class RowMatrix(
212212
tol: Double,
213213
mode: String): SingularValueDecomposition[RowMatrix, Matrix] = {
214214
val n = numCols().toInt
215-
require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.")
215+
require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
216216

217217
object SVDMode extends Enumeration {
218218
val LocalARPACK, LocalLAPACK, DistARPACK = Value

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,13 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
113113
assert(closeToZero(U * brzDiag(s) * V.t - localA))
114114
}
115115

116+
test("validate k in svd") {
117+
val A = new IndexedRowMatrix(indexedRows)
118+
intercept[IllegalArgumentException] {
119+
A.computeSVD(-1)
120+
}
121+
}
122+
116123
def closeToZero(G: BDM[Double]): Boolean = {
117124
G.valuesIterator.map(math.abs).sum < 1e-6
118125
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,14 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
171171
}
172172
}
173173

174+
test("validate k in svd") {
175+
for (mat <- Seq(denseMat, sparseMat)) {
176+
intercept[IllegalArgumentException] {
177+
mat.computeSVD(-1)
178+
}
179+
}
180+
}
181+
174182
def closeToZero(G: BDM[Double]): Boolean = {
175183
G.valuesIterator.map(math.abs).sum < 1e-6
176184
}

0 commit comments

Comments
 (0)