@@ -32,31 +32,31 @@ import org.apache.spark.sql.types.StructType
3232 * :: Experimental ::
3333 *
 * Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function is
35- * a perfect hash function for a specific set `S` with cardinality equal to a half of `numEntries`:
36- * `h_i(x) = ((x \cdot k_i ) \mod prime) \mod numEntries`
35+ * a perfect hash function for a specific set `S` with cardinality equal to `numEntries`:
36+ * `h_i(x) = ((x \cdot a_i + b_i ) \mod prime) \mod numEntries`
3737 *
3838 * @param numEntries The number of entries of the hash functions.
3939 * @param randCoefficients An array of random coefficients, each used by one hash function.
4040 */
4141@ Experimental
4242@ Since (" 2.1.0" )
43- class MinHashModel private [ml] (
43+ class MinHashLSHModel private [ml](
4444 override val uid : String ,
45- @ Since ( " 2.1.0 " ) private [ml] val numEntries : Int ,
46- @ Since ( " 2.1.0 " ) private [ml] val randCoefficients : Array [Int ])
47- extends LSHModel [MinHashModel ] {
45+ private [ml] val numEntries : Int ,
46+ private [ml] val randCoefficients : Array [( Int , Int ) ])
47+ extends LSHModel [MinHashLSHModel ] {
4848
4949 @ Since (" 2.1.0" )
5050 override protected [ml] val hashFunction : Vector => Array [Vector ] = {
5151 elems : Vector => {
5252 require(elems.numNonzeros > 0 , " Must have at least 1 non zero entry." )
5353 val elemsList = elems.toSparse.indices.toList
54- val hashValues = randCoefficients.map({ randCoefficient : Int =>
55- elemsList.map( { elem : Int =>
56- (1 + elem) * randCoefficient.toLong % MinHashLSH .prime % numEntries
57- }) .min.toDouble
54+ val hashValues = randCoefficients.map({ case ( a : Int , b : Int ) =>
55+ elemsList.map { elem : Int =>
56+ (( 1 + elem) * a + b) % MinHashLSH .HASH_PRIME % numEntries
57+ }.min.toDouble
5858 })
59- // TODO: For AND-amplification, output vectors of dimension numHashFunctions
59+ // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
6060 hashValues.grouped(1 ).map(Vectors .dense).toArray
6161 }
6262 }
@@ -74,7 +74,7 @@ class MinHashModel private[ml] (
7474 @ Since (" 2.1.0" )
7575 override protected [ml] def hashDistance (x : Seq [Vector ], y : Seq [Vector ]): Double = {
7676 // Since it's generated by hashing, it will be a pair of dense vectors.
77- // TODO: This hashDistance function is controversial. Requires more discussion.
77+ // TODO: This hashDistance function requires more discussion in SPARK-18454
7878 x.zip(y).map(vectorPair =>
7979 vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
8080 ).min
@@ -84,7 +84,7 @@ class MinHashModel private[ml] (
  /** Creates a copy of this model with the extra params merged in. */
  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
8585
  /** Returns a writer that saves this model (params metadata plus numEntries and coefficients). */
  @Since("2.1.0")
  override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
8888}
8989
9090/**
@@ -93,17 +93,17 @@ class MinHashModel private[ml] (
9393 * LSH class for Jaccard distance.
9494 *
9595 * The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
96- * `Vectors.sparse(10, Array[( 2, 1.0), (3, 1.0), (5, 1.0)] )`
97- * means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
98- * Also, any input vector must have at least 1 non-zero indices , and all non-zero values are treated
99- * as binary "1" values.
96+ * `Vectors.sparse(10, Array(( 2, 1.0), (3, 1.0), (5, 1.0)) )`
97+ * means there are 10 elements in the space. This set contains non-zero values at indices 2, 3, and
98+ * 5. Also, any input vector must have at least 1 non-zero index , and all non-zero values are
99+ * treated as binary "1" values.
100100 *
101101 * References:
102102 * [[https://en.wikipedia.org/wiki/MinHash Wikipedia on MinHash ]]
103103 */
104104@ Experimental
105105@ Since (" 2.1.0" )
106- class MinHashLSH (override val uid : String ) extends LSH [MinHashModel ] with HasSeed {
106+ class MinHashLSH (override val uid : String ) extends LSH [MinHashLSHModel ] with HasSeed {
107107
  /** @group setParam */
  @Since("2.1.0")
  override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -116,21 +116,23 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee
116116
  /** Creates a MinHashLSH instance with a randomly generated UID (prefix "mh-lsh"). */
  @Since("2.1.0")
  def this() = {
    this(Identifiable.randomUID("mh-lsh"))
  }
121121
  /**
   * Sets the random seed used to draw the hash coefficients in createRawLSHModel.
   * @group setParam
   */
  @Since("2.1.0")
  def setSeed(value: Long): this.type = set(seed, value)
125125
126126 @ Since (" 2.1.0" )
127- override protected [ml] def createRawLSHModel (inputDim : Int ): MinHashModel = {
128- require(inputDim <= MinHashLSH .prime / 2 ,
129- s " The input vector dimension $inputDim exceeds the threshold ${MinHashLSH .prime / 2 }. " )
127+ override protected [ml] def createRawLSHModel (inputDim : Int ): MinHashLSHModel = {
128+ require(inputDim <= MinHashLSH .HASH_PRIME ,
129+ s " The input vector dimension $inputDim exceeds the threshold ${MinHashLSH .HASH_PRIME }. " )
130130 val rand = new Random ($(seed))
131- val numEntry = inputDim * 2
132- val randCoofs : Array [Int ] = Array .fill($(numHashTables))(1 + rand.nextInt(MinHashLSH .prime - 1 ))
133- new MinHashModel (uid, numEntry, randCoofs)
131+ val numEntry = inputDim
132+ val randCoefs : Array [(Int , Int )] = Array .fill(2 * $(numHashTables)) {
133+ (1 + rand.nextInt(MinHashLSH .HASH_PRIME - 1 ), rand.nextInt(MinHashLSH .HASH_PRIME - 1 ))
134+ }
135+ new MinHashLSHModel (uid, numEntry, randCoefs)
134136 }
135137
136138 @ Since (" 2.1.0" )
@@ -146,46 +148,49 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashModel] with HasSee
@Since("2.1.0")
object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
  // A large prime smaller than sqrt(2^63 - 1): the product of two values below HASH_PRIME
  // fits in a signed 64-bit Long, so modular hashing can be done without overflow.
  private[ml] val HASH_PRIME = 2038074743

  /** Loads a MinHashLSH estimator saved with DefaultParamsWritable. */
  @Since("2.1.0")
  override def load(path: String): MinHashLSH = super.load(path)
}
154156
155157@ Since (" 2.1.0" )
156- object MinHashModel extends MLReadable [MinHashModel ] {
158+ object MinHashLSHModel extends MLReadable [MinHashLSHModel ] {
157159
  /** Returns a reader that can load models saved by MinHashLSHModelWriter. */
  @Since("2.1.0")
  override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader
160162
  /** Loads a MinHashLSHModel from the given path. */
  @Since("2.1.0")
  override def load(path: String): MinHashLSHModel = super.load(path)
163165
164- private [MinHashModel ] class MinHashModelWriter (instance : MinHashModel ) extends MLWriter {
166+ private [MinHashLSHModel ] class MinHashLSHModelWriter (instance : MinHashLSHModel )
167+ extends MLWriter {
165168
166169 private case class Data (numEntries : Int , randCoefficients : Array [Int ])
167170
168171 override protected def saveImpl (path : String ): Unit = {
169172 DefaultParamsWriter .saveMetadata(instance, path, sc)
170- val data = Data (instance.numEntries, instance.randCoefficients)
173+ val data = Data (instance.numEntries, instance.randCoefficients
174+ .flatMap(tuple => Array (tuple._1, tuple._2)))
171175 val dataPath = new Path (path, " data" ).toString
172176 sparkSession.createDataFrame(Seq (data)).repartition(1 ).write.parquet(dataPath)
173177 }
174178 }
175179
  /** Reader that reconstructs a model written by MinHashLSHModelWriter. */
  private class MinHashLSHModelReader extends MLReader[MinHashLSHModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[MinHashLSHModel].getName

    override def load(path: String): MinHashLSHModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head()
      val numEntries = data.getAs[Int](0)
      // Coefficients were stored flattened; regroup consecutive ints back into (a, b) pairs.
      val randCoefficients = data.getAs[Seq[Int]](1).grouped(2)
        .map(tuple => (tuple(0), tuple(1))).toArray
      val model = new MinHashLSHModel(metadata.uid, numEntries, randCoefficients)

      DefaultParamsReader.getAndSetParams(model, metadata)
      model
0 commit comments