Skip to content

Commit 44563a0

Browse files
zhengruifeng authored and srowen committed
[SPARK-33518][ML] Improve performance of ML ALS recommendForAll by GEMV
### What changes were proposed in this pull request? There has been a lot of work on improving ALS's recommendForAll. For now, I found that it may be further optimized by: 1, using GEMV and sharing a pre-allocated buffer per task; 2, using guava.ordering instead of BoundedPriorityQueue; ### Why are the changes needed? In my test, using `f2jBLAS.sgemv`, it is about 2.3X faster than the existing impl. |Impl| Master | GEMM | GEMV | GEMV + array aggregator | GEMV + guava ordering + array aggregator | GEMV + guava ordering| |------|----------|------------|----------|------------|------------|------------| |Duration|341229|363741|191201|189790|148417|147222| ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing testsuites Closes #30468 from zhengruifeng/als_rec_opt. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Sean Owen <srowen@gmail.com>
1 parent de234ee commit 44563a0

1 file changed

Lines changed: 33 additions & 20 deletions

File tree

  • mllib/src/main/scala/org/apache/spark/ml/recommendation

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import scala.util.{Sorting, Try}
2727
import scala.util.hashing.byteswap64
2828

2929
import com.github.fommil.netlib.BLAS.{getInstance => blas}
30+
import com.google.common.collect.{Ordering => GuavaOrdering}
3031
import org.apache.hadoop.fs.Path
3132
import org.json4s.DefaultFormats
3233
import org.json4s.JsonDSL._
@@ -47,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
4748
import org.apache.spark.sql.functions._
4849
import org.apache.spark.sql.types._
4950
import org.apache.spark.storage.StorageLevel
50-
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
51+
import org.apache.spark.util.Utils
5152
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
5253
import org.apache.spark.util.random.XORShiftRandom
5354

@@ -456,30 +457,39 @@ class ALSModel private[ml] (
456457
num: Int,
457458
blockSize: Int): DataFrame = {
458459
import srcFactors.sparkSession.implicits._
460+
import scala.collection.JavaConverters._
459461

460462
val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
461463
val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
462464
val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
463-
.as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
464-
.flatMap { case (srcIter, dstIter) =>
465-
val m = srcIter.size
466-
val n = math.min(dstIter.size, num)
467-
val output = new Array[(Int, Int, Float)](m * n)
468-
var i = 0
469-
val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
470-
srcIter.foreach { case (srcId, srcFactor) =>
471-
dstIter.foreach { case (dstId, dstFactor) =>
472-
// We use F2jBLAS which is faster than a call to native BLAS for vector dot product
473-
val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
474-
pq += dstId -> score
465+
.as[(Array[Int], Array[Float], Array[Int], Array[Float])]
466+
.mapPartitions { iter =>
467+
var scores: Array[Float] = null
468+
var idxOrd: GuavaOrdering[Int] = null
469+
iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
470+
require(srcMat.length == srcIds.length * rank)
471+
require(dstMat.length == dstIds.length * rank)
472+
val m = srcIds.length
473+
val n = dstIds.length
474+
if (scores == null || scores.length < n) {
475+
scores = Array.ofDim[Float](n)
476+
idxOrd = new GuavaOrdering[Int] {
477+
override def compare(left: Int, right: Int): Int = {
478+
Ordering[Float].compare(scores(left), scores(right))
479+
}
480+
}
475481
}
476-
pq.foreach { case (dstId, score) =>
477-
output(i) = (srcId, dstId, score)
478-
i += 1
482+
483+
Iterator.range(0, m).flatMap { i =>
484+
// buffer = i-th vec in srcMat * dstMat
485+
BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
486+
srcMat, i * rank, 1, 0.0F, scores, 0, 1)
487+
488+
val srcId = srcIds(i)
489+
idxOrd.greatestOf(Iterator.range(0, n).asJava, num).asScala
490+
.iterator.map { j => (srcId, dstIds(j), scores(j)) }
479491
}
480-
pq.clear()
481492
}
482-
output.toSeq
483493
}
484494
// We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
485495
val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
@@ -499,9 +509,12 @@ class ALSModel private[ml] (
499509
*/
500510
private def blockify(
501511
factors: Dataset[(Int, Array[Float])],
502-
blockSize: Int): Dataset[Seq[(Int, Array[Float])]] = {
512+
blockSize: Int): Dataset[(Array[Int], Array[Float])] = {
503513
import factors.sparkSession.implicits._
504-
factors.mapPartitions(_.grouped(blockSize))
514+
factors.mapPartitions { iter =>
515+
iter.grouped(blockSize)
516+
.map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
517+
}
505518
}
506519

507520
}

0 commit comments

Comments
 (0)