Skip to content

Commit 033ae5d

Browse files
author
Yun Ni
committed
Code Review Comments
1 parent c115ed3 commit 033ae5d

5 files changed

Lines changed: 147 additions & 95 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType
3636
*
3737
* Params for [[BucketedRandomProjectionLSH]].
3838
*/
39-
private[ml] trait BucketedRandomProjectionParams extends Params {
39+
private[ml] trait BucketedRandomProjectionLSHParams extends Params {
4040

4141
/**
4242
* The length of each hash bucket. A larger bucket lowers the false negative rate. The number of
@@ -68,18 +68,18 @@ private[ml] trait BucketedRandomProjectionParams extends Params {
6868
*/
6969
@Experimental
7070
@Since("2.1.0")
71-
class BucketedRandomProjectionModel private[ml](
71+
class BucketedRandomProjectionLSHModel private[ml](
7272
override val uid: String,
73-
@Since("2.1.0") private[ml] val randUnitVectors: Array[Vector])
74-
extends LSHModel[BucketedRandomProjectionModel] with BucketedRandomProjectionParams {
73+
private[ml] val randUnitVectors: Array[Vector])
74+
extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {
7575

7676
@Since("2.1.0")
7777
override protected[ml] val hashFunction: Vector => Array[Vector] = {
7878
key: Vector => {
7979
val hashValues: Array[Double] = randUnitVectors.map({
8080
randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
8181
})
82-
// TODO: For AND-amplification, output vectors of dimension numHashFunctions
82+
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
8383
hashValues.grouped(1).map(Vectors.dense).toArray
8484
}
8585
}
@@ -100,7 +100,7 @@ class BucketedRandomProjectionModel private[ml](
100100

101101
@Since("2.1.0")
102102
override def write: MLWriter = {
103-
new BucketedRandomProjectionModel.BucketedRandomProjectionModelWriter(this)
103+
new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
104104
}
105105
}
106106

@@ -111,8 +111,8 @@ class BucketedRandomProjectionModel private[ml](
111111
* Euclidean distance metrics.
112112
*
113113
* The input is dense or sparse vectors, each of which represents a point in the Euclidean
114-
* distance space. The output will be vectors of configurable dimension. Hash value in the same
115-
* dimension is calculated by the same hash function.
114+
* distance space. The output will be vectors of configurable dimension. Hash values in the
115+
* same dimension are calculated by the same hash function.
116116
*
117117
* References:
118118
*
@@ -125,7 +125,8 @@ class BucketedRandomProjectionModel private[ml](
125125
@Experimental
126126
@Since("2.1.0")
127127
class BucketedRandomProjectionLSH(override val uid: String)
128-
extends LSH[BucketedRandomProjectionModel] with BucketedRandomProjectionParams with HasSeed {
128+
extends LSH[BucketedRandomProjectionLSHModel]
129+
with BucketedRandomProjectionLSHParams with HasSeed {
129130

130131
@Since("2.1.0")
131132
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -138,7 +139,7 @@ class BucketedRandomProjectionLSH(override val uid: String)
138139

139140
@Since("2.1.0")
140141
def this() = {
141-
this(Identifiable.randomUID("random projection"))
142+
this(Identifiable.randomUID("brp-lsh"))
142143
}
143144

144145
/** @group setParam */
@@ -150,15 +151,17 @@ class BucketedRandomProjectionLSH(override val uid: String)
150151
def setSeed(value: Long): this.type = set(seed, value)
151152

152153
@Since("2.1.0")
153-
override protected[this] def createRawLSHModel(inputDim: Int): BucketedRandomProjectionModel = {
154+
override protected[this] def createRawLSHModel(
155+
inputDim: Int
156+
): BucketedRandomProjectionLSHModel = {
154157
val rand = new Random($(seed))
155158
val randUnitVectors: Array[Vector] = {
156159
Array.fill($(numHashTables)) {
157160
val randArray = Array.fill(inputDim)(rand.nextGaussian())
158161
Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
159162
}
160163
}
161-
new BucketedRandomProjectionModel(uid, randUnitVectors)
164+
new BucketedRandomProjectionLSHModel(uid, randUnitVectors)
162165
}
163166

164167
@Since("2.1.0")
@@ -179,18 +182,18 @@ object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomP
179182
}
180183

181184
@Since("2.1.0")
182-
object BucketedRandomProjectionModel extends MLReadable[BucketedRandomProjectionModel] {
185+
object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProjectionLSHModel] {
183186

184187
@Since("2.1.0")
185-
override def read: MLReader[BucketedRandomProjectionModel] = {
186-
new BucketedRandomProjectionModelReader
188+
override def read: MLReader[BucketedRandomProjectionLSHModel] = {
189+
new BucketedRandomProjectionLSHModelReader
187190
}
188191

189192
@Since("2.1.0")
190-
override def load(path: String): BucketedRandomProjectionModel = super.load(path)
193+
override def load(path: String): BucketedRandomProjectionLSHModel = super.load(path)
191194

192-
private[BucketedRandomProjectionModel] class BucketedRandomProjectionModelWriter(
193-
instance: BucketedRandomProjectionModel) extends MLWriter {
195+
private[BucketedRandomProjectionLSHModel] class BucketedRandomProjectionLSHModelWriter(
196+
instance: BucketedRandomProjectionLSHModel) extends MLWriter {
194197

195198
// TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
196199
private case class Data(randUnitVectors: Matrix)
@@ -208,21 +211,22 @@ object BucketedRandomProjectionModel extends MLReadable[BucketedRandomProjection
208211
}
209212
}
210213

211-
private class BucketedRandomProjectionModelReader
212-
extends MLReader[BucketedRandomProjectionModel] {
214+
private class BucketedRandomProjectionLSHModelReader
215+
extends MLReader[BucketedRandomProjectionLSHModel] {
213216

214217
/** Checked against metadata when loading model */
215-
private val className = classOf[BucketedRandomProjectionModel].getName
218+
private val className = classOf[BucketedRandomProjectionLSHModel].getName
216219

217-
override def load(path: String): BucketedRandomProjectionModel = {
220+
override def load(path: String): BucketedRandomProjectionLSHModel = {
218221
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
219222

220223
val dataPath = new Path(path, "data").toString
221224
val data = sparkSession.read.parquet(dataPath)
222225
val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
223226
.select("randUnitVectors")
224227
.head()
225-
val model = new BucketedRandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray)
228+
val model = new BucketedRandomProjectionLSHModel(metadata.uid,
229+
randUnitVectors.rowIter.toArray)
226230

227231
DefaultParamsReader.getAndSetParams(model, metadata)
228232
model

mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ import org.apache.spark.sql.types._
3333
*/
3434
private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
3535
/**
36-
* Param for the dimension of LSH OR-amplification.
36+
* Param for the number of hash tables used in LSH OR-amplification.
3737
*
38-
* LSH OR-amplification can be used to reduce the false negative rate. The higher the dimension
39-
* is, the lower the false negative rate.
38+
* LSH OR-amplification can be used to reduce the false negative rate. Higher values for this
39+
* param lead to a reduced false negative rate, at the expense of added computational complexity.
4040
* @group param
4141
*/
4242
final val numHashTables: IntParam = new IntParam(this, "numHashTables", "number of hash " +
@@ -66,7 +66,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
6666
self: T =>
6767

6868
/**
69-
* The hash function of LSH, mapping an input feature to multiple vectors
69+
* The hash function of LSH, mapping an input feature vector to multiple hash vectors.
7070
* @return The mapping of the LSH function.
7171
*/
7272
protected[ml] val hashFunction: Vector => Array[Vector]
@@ -99,26 +99,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
9999
validateAndTransformSchema(schema)
100100
}
101101

102-
/**
103-
* Given a large dataset and an item, approximately find at most k items which have the closest
104-
* distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
105-
* the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
106-
* transformed data when necessary.
107-
*
108-
* This method implements two ways of fetching k nearest neighbors:
109-
* - Single-probe: Fast, return at most k elements (Probing only one buckets)
110-
* - Multi-probe: Slow, return exact k elements (Probing multiple buckets close to the key)
111-
*
112-
* Currently it is made private since more discussion is needed for Multi-probe
113-
*
114-
* @param dataset the dataset to search for nearest neighbors of the key
115-
* @param key Feature vector representing the item to search for
116-
* @param numNearestNeighbors The maximum number of nearest neighbors
117-
* @param singleProbe True for using single-probe; false for multi-probe
118-
* @param distCol Output column for storing the distance between each result row and the key
119-
* @return A dataset containing at most k items closest to the key. A distCol is added to show
120-
* the distance between each row and the key.
121-
*/
102+
// TODO: Fix the MultiProbe NN Search in SPARK-18454
122103
private[feature] def approxNearestNeighbors(
123104
dataset: Dataset[_],
124105
key: Vector,
@@ -179,7 +160,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
179160
* @return A dataset containing at most k items closest to the key. A distCol is added to show
180161
* the distance between each row and the key.
181162
*/
182-
private[feature] def approxNearestNeighbors(
163+
def approxNearestNeighbors(
183164
dataset: Dataset[_],
184165
key: Vector,
185166
numNearestNeighbors: Int,

mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,31 @@ import org.apache.spark.sql.types.StructType
3232
* :: Experimental ::
3333
*
3434
* Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function is
35-
* a perfect hash function for a specific set `S` with cardinality equal to a half of `numEntries`:
36-
* `h_i(x) = ((x \cdot k_i) \mod prime) \mod numEntries`
35+
* a perfect hash function for a specific set `S` with cardinality equal to `numEntries`:
36+
* `h_i(x) = ((x \cdot a_i + b_i) \mod prime) \mod numEntries`
3737
*
3838
* @param numEntries The number of entries of the hash functions.
3939
* @param randCoefficients Array of random coefficient pairs `(a_i, b_i)`, one per hash function.
4040
*/
4141
@Experimental
4242
@Since("2.1.0")
43-
class MinHashModel private[ml] (
43+
class MinHashLSHModel private[ml](
4444
override val uid: String,
45-
@Since("2.1.0") private[ml] val numEntries: Int,
46-
@Since("2.1.0") private[ml] val randCoefficients: Array[Int])
47-
extends LSHModel[MinHashModel] {
45+
private[ml] val numEntries: Int,
46+
private[ml] val randCoefficients: Array[(Int, Int)])
47+
extends LSHModel[MinHashLSHModel] {
4848

4949
@Since("2.1.0")
5050
override protected[ml] val hashFunction: Vector => Array[Vector] = {
5151
elems: Vector => {
5252
require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
5353
val elemsList = elems.toSparse.indices.toList
54-
val hashValues = randCoefficients.map({ randCoefficient: Int =>
55-
elemsList.map({ elem: Int =>
56-
(1 + elem) * randCoefficient.toLong % MinHashLSH.prime % numEntries
57-
}).min.toDouble
54+
val hashValues = randCoefficients.map({ case (a: Int, b: Int) =>
55+
elemsList.map { elem: Int =>
56+
// Widen to Long before multiplying: `a` can be close to HASH_PRIME (~2.04e9), so
// (1 + elem) * a overflows Int and can yield negative hash values.
((1L + elem) * a + b) % MinHashLSH.HASH_PRIME % numEntries
57+
}.min.toDouble
5858
})
59-
// TODO: For AND-amplification, output vectors of dimension numHashFunctions
59+
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
6060
hashValues.grouped(1).map(Vectors.dense).toArray
6161
}
6262
}
@@ -74,7 +74,7 @@ class MinHashModel private[ml] (
7474
@Since("2.1.0")
7575
override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
7676
// Since it's generated by hashing, it will be a pair of dense vectors.
77-
// TODO: This hashDistance function is controversial. Requires more discussion.
77+
// TODO: This hashDistance function requires more discussion in SPARK-18454
7878
x.zip(y).map(vectorPair =>
7979
vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
8080
).min
@@ -84,7 +84,7 @@ class MinHashModel private[ml] (
8484
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
8585

8686
@Since("2.1.0")
87-
override def write: MLWriter = new MinHashModel.MinHashModelWriter(this)
87+
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
8888
}
8989

9090
/**
@@ -93,17 +93,17 @@ class MinHashModel private[ml] (
9393
* LSH class for Jaccard distance.
9494
*
9595
* The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
96-
* `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])`
97-
* means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
98-
* Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated
99-
* as binary "1" values.
96+
* `Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0)))`
97+
* means there are 10 elements in the space. This set contains elements 2, 3, and
98+
* 5. Also, any input vector must have at least 1 non-zero index, and all non-zero values are
99+
* treated as binary "1" values.
100100
*
101101
* References:
102102
* [[https://en.wikipedia.org/wiki/MinHash Wikipedia on MinHash]]
103103
*/
104104
@Experimental
105105
@Since("2.1.0")
106-
class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSeed {
106+
class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with HasSeed {
107107

108108
@Since("2.1.0")
109109
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -116,21 +116,23 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee
116116

117117
@Since("2.1.0")
118118
def this() = {
119-
this(Identifiable.randomUID("min hash"))
119+
this(Identifiable.randomUID("mh-lsh"))
120120
}
121121

122122
/** @group setParam */
123123
@Since("2.1.0")
124124
def setSeed(value: Long): this.type = set(seed, value)
125125

126126
@Since("2.1.0")
127-
override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = {
128-
require(inputDim <= MinHashLSH.prime / 2,
129-
s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.prime / 2}.")
127+
override protected[ml] def createRawLSHModel(inputDim: Int): MinHashLSHModel = {
128+
require(inputDim <= MinHashLSH.HASH_PRIME,
129+
s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.HASH_PRIME}.")
130130
val rand = new Random($(seed))
131-
val numEntry = inputDim * 2
132-
val randCoofs: Array[Int] = Array.fill($(numHashTables))(1 + rand.nextInt(MinHashLSH.prime - 1))
133-
new MinHashModel(uid, numEntry, randCoofs)
131+
val numEntry = inputDim
132+
val randCoefs: Array[(Int, Int)] = Array.fill(2 * $(numHashTables)) {
133+
(1 + rand.nextInt(MinHashLSH.HASH_PRIME - 1), rand.nextInt(MinHashLSH.HASH_PRIME - 1))
134+
}
135+
new MinHashLSHModel(uid, numEntry, randCoefs)
134136
}
135137

136138
@Since("2.1.0")
@@ -146,46 +148,49 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee
146148
@Since("2.1.0")
147149
object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
148150
// A large prime smaller than sqrt(2^63 - 1)
149-
private[ml] val prime = 2038074743
151+
private[ml] val HASH_PRIME = 2038074743
150152

151153
@Since("2.1.0")
152154
override def load(path: String): MinHashLSH = super.load(path)
153155
}
154156

155157
@Since("2.1.0")
156-
object MinHashModel extends MLReadable[MinHashModel] {
158+
object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
157159

158160
@Since("2.1.0")
159-
override def read: MLReader[MinHashModel] = new MinHashModelReader
161+
override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader
160162

161163
@Since("2.1.0")
162-
override def load(path: String): MinHashModel = super.load(path)
164+
override def load(path: String): MinHashLSHModel = super.load(path)
163165

164-
private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter {
166+
private[MinHashLSHModel] class MinHashLSHModelWriter(instance: MinHashLSHModel)
167+
extends MLWriter {
165168

166169
private case class Data(numEntries: Int, randCoefficients: Array[Int])
167170

168171
override protected def saveImpl(path: String): Unit = {
169172
DefaultParamsWriter.saveMetadata(instance, path, sc)
170-
val data = Data(instance.numEntries, instance.randCoefficients)
173+
val data = Data(instance.numEntries, instance.randCoefficients
174+
.flatMap(tuple => Array(tuple._1, tuple._2)))
171175
val dataPath = new Path(path, "data").toString
172176
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
173177
}
174178
}
175179

176-
private class MinHashModelReader extends MLReader[MinHashModel] {
180+
private class MinHashLSHModelReader extends MLReader[MinHashLSHModel] {
177181

178182
/** Checked against metadata when loading model */
179-
private val className = classOf[MinHashModel].getName
183+
private val className = classOf[MinHashLSHModel].getName
180184

181-
override def load(path: String): MinHashModel = {
185+
override def load(path: String): MinHashLSHModel = {
182186
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
183187

184188
val dataPath = new Path(path, "data").toString
185189
val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head()
186190
val numEntries = data.getAs[Int](0)
187-
val randCoefficients = data.getAs[Seq[Int]](1).toArray
188-
val model = new MinHashModel(metadata.uid, numEntries, randCoefficients)
191+
val randCoefficients = data.getAs[Seq[Int]](1).grouped(2)
192+
.map(tuple => (tuple(0), tuple(1))).toArray
193+
val model = new MinHashLSHModel(metadata.uid, numEntries, randCoefficients)
189194

190195
DefaultParamsReader.getAndSetParams(model, metadata)
191196
model

0 commit comments

Comments
 (0)