Skip to content

Commit 9b63900

Browse files
committed
[BRANCH-1.2][SPARK-4604][MLLIB] make MatrixFactorizationModel public
We reverted #3459 in branch-1.2 due to missing `import o.a.s.SparkContext._`, which is no longer needed in master (#3262). This PR adds #3459 back to branch-1.2 with correct imports. GitHub is out-of-sync now. The real changes are the last two commits. Author: Xiangrui Meng <[email protected]> Closes #3473 from mengxr/SPARK-4604-1.2 and squashes the following commits: a7638a5 [Xiangrui Meng] add import o.a.s.SparkContext._ for v1.2 b749000 [Xiangrui Meng] [SPARK-4604][MLLIB] make MatrixFactorizationModel public
1 parent 9f3b159 commit 9b63900

2 files changed

Lines changed: 81 additions & 2 deletions

File tree

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,46 @@ import java.lang.{Integer => JavaInteger}
2121

2222
import org.jblas.DoubleMatrix
2323

24+
import org.apache.spark.Logging
2425
import org.apache.spark.SparkContext._
2526
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
2627
import org.apache.spark.rdd.RDD
28+
import org.apache.spark.storage.StorageLevel
2729

2830
/**
2931
* Model representing the result of matrix factorization.
3032
*
33+
* Note: If you create the model directly using the constructor, please be aware that fast prediction
34+
* requires cached user/product features and their associated partitioners.
35+
*
3136
* @param rank Rank for the features in this model.
3237
* @param userFeatures RDD of tuples where each tuple represents the userId and
3338
* the features computed for this user.
3439
* @param productFeatures RDD of tuples where each tuple represents the productId
3540
* and the features computed for this product.
3641
*/
37-
class MatrixFactorizationModel private[mllib] (
42+
class MatrixFactorizationModel(
3843
val rank: Int,
3944
val userFeatures: RDD[(Int, Array[Double])],
40-
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
45+
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
46+
47+
require(rank > 0)
48+
validateFeatures("User", userFeatures)
49+
validateFeatures("Product", productFeatures)
50+
51+
/**
 * Sanity-checks a factor RDD and logs warnings about likely performance problems.
 *
 * Fails with `IllegalArgumentException` when the feature dimension of the first
 * record does not equal `rank`; warns (but does not fail) when the RDD lacks a
 * partitioner or is not cached, since both make per-record prediction slow.
 *
 * @param name human-readable factor name ("User" or "Product"), used in messages
 * @param features factor RDD of (id, featureVector) pairs to validate
 */
private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
  // Only the first record is inspected; a full scan would be too expensive here.
  val dim = features.first()._2.length
  require(dim == rank, s"$name feature dimension does not match the rank $rank.")
  if (features.partitioner.isEmpty) {
    // Without a partitioner, lookup() must scan every partition.
    logWarning(
      s"$name factor does not have a partitioner. Prediction on individual records could be slow.")
  }
  if (features.getStorageLevel == StorageLevel.NONE) {
    // Uncached factors are recomputed from lineage on every prediction.
    logWarning(s"$name factor is not cached. Prediction could be slow.")
  }
}
63+
4164
/** Predict the rating of one user for one product. */
4265
def predict(user: Int, product: Int): Double = {
4366
val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.recommendation
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.mllib.util.MLlibTestSparkContext
23+
import org.apache.spark.mllib.util.TestingUtils._
24+
import org.apache.spark.rdd.RDD
25+
26+
/**
 * Tests for [[MatrixFactorizationModel]]'s constructor validation: a well-formed
 * model must predict correctly, and mismatches between `rank` and the feature
 * dimensions must be rejected with `IllegalArgumentException`.
 */
class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {

  val rank = 2
  var userFeatures: RDD[(Int, Array[Double])] = _
  var prodFeatures: RDD[(Int, Array[Double])] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Two users and one product, each with rank-2 feature vectors.
    userFeatures = sc.parallelize(Seq(0 -> Array(1.0, 2.0), 1 -> Array(3.0, 4.0)))
    prodFeatures = sc.parallelize(Seq(2 -> Array(5.0, 6.0)))
  }

  test("constructor") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    // dot([1, 2], [5, 6]) = 17
    assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)

    // Declared rank disagrees with the feature dimension.
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(1, userFeatures, prodFeatures)
    }

    // User feature vectors shorter than the rank.
    val badUserFeatures = sc.parallelize(Seq(0 -> Array(1.0), 1 -> Array(3.0)))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, badUserFeatures, prodFeatures)
    }

    // Product feature vectors shorter than the rank.
    val badProdFeatures = sc.parallelize(Seq(2 -> Array(5.0)))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures, badProdFeatures)
    }
  }
}

0 commit comments

Comments
 (0)