diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index c02ba426fcc3a..5439b3acc30d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -550,7 +550,22 @@ class SparseMatrix @Since("1.3.0") ( values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) override def equals(o: Any): Boolean = o match { - case m: Matrix => toBreeze == m.toBreeze + case m : Matrix => + if (this.numRows != m.numRows || this.numCols != m.numCols) return false + if (this.numNonzeros != m.numNonzeros) return false + val activeIterator = toBreeze.activeIterator + m match { + case s: SparseMatrix => + val filterIter = s.toBreeze.activeIterator.withFilter(_._2 != 0.0) + val currFilterIter = activeIterator.withFilter(_._2 != 0.0) + filterIter.sameElements(currFilterIter) + case d: DenseMatrix => + while(activeIterator.hasNext){ + val next = activeIterator.next() + if(d.apply(next._1._1, next._1._2) != next._2) return false + } + true + } case _ => false } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index bfd6d5495f5e0..33b54813f425e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -82,6 +82,9 @@ class MatricesSuite extends SparkFunSuite { val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0)) assert(dm1 === dm2.transpose) + val dmz = Matrices.dense(2, 2, Array(0.0, 0.0, 0.0, 0.0)) + assert(dmz === dmz.transpose) + val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse assert(sm1 === sm1) assert(sm1 === dm1) @@ -90,6 +93,31 @@ class MatricesSuite extends SparkFunSuite { val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse assert(sm1 === sm2.transpose) assert(sm1 === dm2.transpose) + + val smz = dmz.asInstanceOf[DenseMatrix].toSparse + val sm2z = dm2.asInstanceOf[DenseMatrix].toSparse + sm2z(0,1) = 0.0 + sm2z(1,0) = 0.0 + sm2z(1,1) = 0.0 + assert(smz === sm2z) + assert(smz === dmz) + assert(sm2z === dmz) + assert(sm2z === smz) + val dm3 = Matrices.dense(3, 3, Array(1.0, 0.0, 4.0, 0.0, 1.0, 0.0, 3.0, 0.0, 1.0)) + val sm3 = dm3.asInstanceOf[DenseMatrix].toSparse + assert(sm3 === sm3) + assert(sm3 === dm3) + assert(sm3 !== sm3.transpose) + + val sm3explicit = Matrices.sparse(3, 3, colPtrs = Array(0, 3, 6, 9), + rowIndices = Array(0, 1, 2, 0, 1, 2, 0, 1, 2), values = dm3.toArray) + assert(sm3 === sm3explicit) + + val dmEye = Matrices.eye(10) + val smEye = Matrices.speye(10) + assert(smEye === smEye) + assert(smEye === dmEye) + assert(smEye === smEye.transpose) } test("matrix copies are deep copies") {