Skip to content

Commit ad82dac

Browse files
committed
[SPARK-23469] HashingTF should use corrected MurmurHash3 implementation
1 parent 44c28d7 commit ad82dac

7 files changed

Lines changed: 52 additions & 14 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ import org.apache.spark.ml.linalg.Vectors
2626
import org.apache.spark.ml.param._
2727
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
2828
import org.apache.spark.ml.util._
29-
import org.apache.spark.mllib.feature.HashingTF.murmur3Hash
29+
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
3030
import org.apache.spark.sql.{DataFrame, Dataset}
3131
import org.apache.spark.sql.functions.{col, udf}
3232
import org.apache.spark.sql.types.{ArrayType, StructType}
3333
import org.apache.spark.util.Utils
34+
import org.apache.spark.util.VersionUtils.majorMinorVersion
3435

3536
/**
3637
* Maps a sequence of terms to their term frequencies using the hashing trick.
@@ -44,7 +45,7 @@ import org.apache.spark.util.Utils
4445
class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
4546
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
4647

47-
private[this] val hashFunc: Any => Int = murmur3Hash
48+
private var hashFunc: Any => Int = FeatureHasher.murmur3Hash
4849

4950
@Since("1.2.0")
5051
def this() = this(Identifiable.randomUID("hashingTF"))
@@ -142,6 +143,29 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
142143
@Since("1.6.0")
143144
object HashingTF extends DefaultParamsReadable[HashingTF] {
144145

146+
private class HashingTFReader extends MLReader[HashingTF] {
147+
148+
private val className = classOf[HashingTF].getName
149+
150+
override def load(path: String): HashingTF = {
151+
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
152+
val hashingTF = new HashingTF(metadata.uid)
153+
metadata.getAndSetParams(hashingTF)
154+
155+
// We support loading old `HashingTF` saved by previous Spark versions.
156+
// Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but new `HashingTF` uses
157+
// `ml.feature.FeatureHasher.murmur3Hash`.
158+
val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion)
159+
if (majorVersion < 3) {
160+
hashingTF.hashFunc = OldHashingTF.murmur3Hash
161+
}
162+
hashingTF
163+
}
164+
}
165+
166+
@Since("3.0.0")
167+
override def read: MLReader[HashingTF] = new HashingTFReader
168+
145169
@Since("1.6.0")
146170
override def load(path: String): HashingTF = super.load(path)
147171
}
Binary file not shown.

mllib/src/test/resources/test-data/hashingTF-pre3.0/metadata/_SUCCESS

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"class":"org.apache.spark.ml.feature.HashingTF","timestamp":1564446310495,"sparkVersion":"2.3.0-SNAPSHOT","uid":"hashingTF_8ced2ab477c1","paramMap":{"binary":true,"numFeatures":100,"outputCol":"features","inputCol":"words"}}

mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,32 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {
7575
Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
7676
assert(features ~== expected absTol 1e-14)
7777
}
78-
78+
7979
test("indexOf method") {
8080
val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words")
8181
val n = 100
8282
val hashingTF = new HashingTF()
8383
.setInputCol("words")
8484
.setOutputCol("features")
8585
.setNumFeatures(n)
86-
val mLlibHashingTF = new MLlibHashingTF(n)
87-
assert(hashingTF.indexOf("a") === mLlibHashingTF.indexOf("a"))
88-
assert(hashingTF.indexOf("b") === mLlibHashingTF.indexOf("b"))
89-
assert(hashingTF.indexOf("c") === mLlibHashingTF.indexOf("c"))
90-
assert(hashingTF.indexOf("d") === mLlibHashingTF.indexOf("d"))
86+
assert(hashingTF.indexOf("a") === 67)
87+
assert(hashingTF.indexOf("b") === 65)
88+
assert(hashingTF.indexOf("c") === 68)
89+
assert(hashingTF.indexOf("d") === 90)
90+
}
91+
92+
test("Load HashingTF prior to Spark 3.0") {
93+
val hashingTFPath = testFile("test-data/hashingTF-pre3.0")
94+
val loadedHashingTF = HashingTF.load(hashingTFPath)
95+
val mLlibHashingTF = new MLlibHashingTF(100)
96+
assert(loadedHashingTF.indexOf("a") === mLlibHashingTF.indexOf("a"))
97+
assert(loadedHashingTF.indexOf("b") === mLlibHashingTF.indexOf("b"))
98+
assert(loadedHashingTF.indexOf("c") === mLlibHashingTF.indexOf("c"))
99+
assert(loadedHashingTF.indexOf("d") === mLlibHashingTF.indexOf("d"))
100+
101+
val metadata = spark.read.json(s"$hashingTFPath/metadata")
102+
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
103+
assert(sparkVersionStr == "2.3.0-SNAPSHOT")
91104
}
92105

93106
test("read/write") {
@@ -103,7 +116,7 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {
103116
object HashingTFSuite {
104117

105118
private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
106-
Utils.nonNegativeMod(MLlibHashingTF.murmur3Hash(term), numFeatures)
119+
Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures)
107120
}
108121

109122
}

python/pyspark/ml/feature.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -902,19 +902,19 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
902902
>>> df = spark.createDataFrame([(["a", "b", "c"],)], ["words"])
903903
>>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
904904
>>> hashingTF.transform(df).head().features
905-
SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0})
905+
SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
906906
>>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
907-
SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0})
907+
SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
908908
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
909909
>>> hashingTF.transform(df, params).head().vector
910-
SparseVector(5, {0: 1.0, 1: 1.0, 2: 1.0})
910+
SparseVector(5, {0: 1.0, 2: 1.0, 3: 1.0})
911911
>>> hashingTFPath = temp_path + "/hashing-tf"
912912
>>> hashingTF.save(hashingTFPath)
913913
>>> loadedHashingTF = HashingTF.load(hashingTFPath)
914914
>>> loadedHashingTF.getNumFeatures() == hashingTF.getNumFeatures()
915915
True
916916
>>> hashingTF.indexOf("b")
917-
1
917+
5
918918
919919
.. versionadded:: 1.3.0
920920
"""

python/pyspark/ml/tests/test_feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def test_apply_binary_term_freqs(self):
296296
hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
297297
output = hashingTF.transform(df)
298298
features = output.select("features").first().features.toArray()
299-
expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray()
299+
expected = Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]).toArray()
300300
for i in range(0, n):
301301
self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) +
302302
": expected " + str(expected[i]) + ", got " + str(features[i]))

0 commit comments

Comments
 (0)