[SPARK-18081][ML][DOCS] Add user guide for Locality Sensitive Hashing(LSH) #15795
@@ -1396,3 +1396,134 @@ for more details on the API.
{% include_example python/ml/chisq_selector_example.py %}
</div>
</div>
# Locality Sensitive Hashing

[Locality Sensitive Hashing (LSH)](https://en.wikipedia.org/wiki/Locality-sensitive_hashing) is an important class of hashing techniques: an LSH family hashes data points so that points which are close to each other under a chosen distance metric are mapped to the same hash bucket with high probability, while points that are far apart are likely to fall into different buckets. Each supported distance metric has its own LSH family class in `spark.ml`, which can transform feature columns into new columns of hash values. Besides feature transformation, `spark.ml` also implements approximate nearest neighbor search and approximate similarity join on top of LSH.

In this section, we call a pair of input features a false positive if the two features are hashed into the same bucket even though they are far apart in distance, and a false negative if they are close in distance but are not hashed into the same bucket.
## Random Projection for Euclidean Distance

**Note:** This is different from the [Random Projection for cosine distance](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Random_projection).
[Random Projection](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions) is the LSH family in `spark.ml` for Euclidean distance. The Euclidean distance is defined as follows:
`\[
d(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_i (x_i - y_i)^2}
\]`
Its LSH family projects feature vectors onto a random unit vector and divides the projected results into hash buckets:
`\[
h(\mathbf{x}) = \lfloor \frac{\mathbf{x} \cdot \mathbf{v}}{r} \rfloor
\]`
where `v` is a random unit vector and `r` is the user-defined bucket length.
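As a rough illustration of the bucketing above, here is a minimal sketch in plain Java (a toy, not the `spark.ml` implementation; `hashBucket`, `x`, `v`, and `r` are hypothetical names, and `v` is assumed to be pre-drawn as a random unit vector):

```java
// Toy sketch of h(x) = floor((x . v) / r) for one hash function
static int hashBucket(double[] x, double[] v, double r) {
  double dot = 0.0;
  for (int i = 0; i < x.length; i++) {
    dot += x[i] * v[i];  // project x onto the random unit vector v
  }
  return (int) Math.floor(dot / r);  // bucket index, with buckets of length r
}
```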
The input features in Euclidean space are represented as vectors. Both sparse and dense vectors are supported.

The bucket length can be used to trade off the performance of random projection: a larger bucket length lowers the false negative rate but usually increases both the running time and the false positive rate.
| <div class="codetabs"> | ||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| Refer to the [RandomProjection Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RandomProjection) | ||
|
||
| for more details on the API. | ||
|
|
||
| {% include_example scala/org/apache/spark/examples/ml/RandomProjectionExample.scala %} | ||
| </div> | ||
|
|
||
| <div data-lang="java" markdown="1"> | ||
|
|
||
| Refer to the [RandomProjection Java docs](api/java/org/apache/spark/ml/feature/RandomProjection.html) | ||
|
||
| for more details on the API. | ||
|
|
||
| {% include_example java/org/apache/spark/examples/ml/JavaRandomProjectionExample.java %} | ||
| </div> | ||
| </div> | ||
|
|
||
## MinHash for Jaccard Distance

[MinHash](https://en.wikipedia.org/wiki/MinHash) is the LSH family in `spark.ml` for Jaccard distance, where input features are sets of natural numbers. The Jaccard distance of two sets is defined by the cardinality of their intersection and union:
`\[
d(\mathbf{A}, \mathbf{B}) = 1 - \frac{|\mathbf{A} \cap \mathbf{B}|}{|\mathbf{A} \cup \mathbf{B}|}
\]`
As its LSH family, MinHash applies a random [perfect hash function](https://en.wikipedia.org/wiki/Perfect_hash_function) `g` to each element in the set and takes the minimum of all hashed values:
`\[
h(\mathbf{A}) = \min_{a \in \mathbf{A}}(g(a))
\]`
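For intuition, here is a toy Java sketch of these two definitions (illustrative only, not the `spark.ml` implementation; the modular hash used as a stand-in for `g` is an assumption of this sketch):

```java
import java.util.HashSet;
import java.util.Set;

class MinHashSketch {
  // One MinHash function h(A) = min over a in A of g(a),
  // with g(a) = (coeff * a + offset) mod prime as a stand-in hash
  static long minHash(Set<Integer> set, long coeff, long offset, long prime) {
    long min = Long.MAX_VALUE;
    for (int a : set) {
      long hashed = (coeff * a + offset) % prime;  // g(a)
      min = Math.min(min, hashed);
    }
    return min;
  }

  // Jaccard distance d(A, B) = 1 - |A intersect B| / |A union B|
  static double jaccardDistance(Set<Integer> a, Set<Integer> b) {
    Set<Integer> intersection = new HashSet<>(a);
    intersection.retainAll(b);
    Set<Integer> union = new HashSet<>(a);
    union.addAll(b);
    return 1.0 - (double) intersection.size() / union.size();
  }
}
```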
Input sets for MinHash are represented as vectors whose dimension equals the total number of elements in the space. Each dimension of the vector represents the status of an element: a zero value means the element is not in the set, while a non-zero value means the set contains the corresponding element. For example, `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])` means there are 10 elements in the space, and this set contains element 2, element 3 and element 5.

**Note:** Empty sets cannot be transformed by MinHash, which means any input vector must have at least one non-zero entry.
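In Java, the same example set could be constructed with the `Vectors.sparse(size, indices, values)` overload (a small sketch; the variable name is illustrative):

```java
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

// The set {2, 3, 5} in a 10-element space: non-zero entries mark membership
Vector exampleSet = Vectors.sparse(10, new int[]{2, 3, 5}, new double[]{1.0, 1.0, 1.0});
```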
| <div class="codetabs"> | ||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| Refer to the [MinHash Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinHash) | ||
|
||
| for more details on the API. | ||
|
|
||
| {% include_example scala/org/apache/spark/examples/ml/MinHashExample.scala %} | ||
| </div> | ||
|
|
||
| <div data-lang="java" markdown="1"> | ||
|
|
||
| Refer to the [MinHash Java docs](api/java/org/apache/spark/ml/feature/MinHash.html) | ||
| for more details on the API. | ||
|
|
||
| {% include_example java/org/apache/spark/examples/ml/JavaMinHashExample.java %} | ||
| </div> | ||
| </div> | ||
|
|
||
## Feature Transformation

Feature transformation is the basic functionality that adds hash results as a new column. Users can specify the input and output column names by setting `inputCol` and `outputCol`. In `spark.ml`, all LSH families can use multiple LSH hash functions; the number of hash functions can be set via `outputDim` in any LSH family.
| <div class="codetabs"> | ||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| {% include_example scala/org/apache/spark/examples/ml/LSHTransformationExample.scala %} | ||
| </div> | ||
|
|
||
| <div data-lang="java" markdown="1"> | ||
|
|
||
| {% include_example java/org/apache/spark/examples/ml/JavaLSHTransformationExample.java %} | ||
| </div> | ||
| </div> | ||
|
|
||
When multiple hash functions are used, users can apply [amplification](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Amplification) to trade off between the false positive and false negative rates (a brief sketch of both rules appears below):

* AND-amplification: Two input vectors are defined to be in the same bucket only if ALL of their hash values match. This decreases the false positive rate but increases the false negative rate.

* OR-amplification: Two input vectors are defined to be in the same bucket as long as ANY of their hash values match. This increases the false positive rate but decreases the false negative rate.

Approximate nearest neighbor search and approximate similarity join use OR-amplification.
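A minimal sketch of the two matching rules, assuming each vector's signature is an equal-length array with one hash value per hash function (`sameBucketAnd` and `sameBucketOr` are hypothetical helper names, not `spark.ml` API):

```java
// AND-amplification: all hash values must match
static boolean sameBucketAnd(long[] h1, long[] h2) {
  for (int i = 0; i < h1.length; i++) {
    if (h1[i] != h2[i]) {
      return false;  // one mismatch is enough to separate the buckets
    }
  }
  return true;
}

// OR-amplification: any single matching hash value suffices
static boolean sameBucketOr(long[] h1, long[] h2) {
  for (int i = 0; i < h1.length; i++) {
    if (h1[i] == h2[i]) {
      return true;  // one match is enough to share a bucket
    }
  }
  return false;
}
```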
## Approximate Similarity Join

Approximate similarity join takes two datasets and approximately returns those pairs of rows whose distance is smaller than a user-defined threshold. Approximate similarity join supports both joining two different datasets and self-joining.

Approximate similarity join allows users to cache the transformed columns when necessary: if `outputCol` is missing, the method transforms the data on the fly; if `outputCol` exists, it uses that column directly.
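Concretely, the caching pattern looks like this (excerpted from the Java example in this patch; `model`, `dfA`, and `dfB` are defined there):

```java
// Pre-transform the inputs so the join reuses the hashed column
Dataset<Row> transformedA = model.transform(dfA);
Dataset<Row> transformedB = model.transform(dfB);
model.approxSimilarityJoin(transformedA, transformedB, 0.6).show();
```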
| <div class="codetabs"> | ||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| {% include_example scala/org/apache/spark/examples/ml/ApproxSimilarityJoinExample.scala %} | ||
| </div> | ||
|
|
||
| <div data-lang="java" markdown="1"> | ||
|
|
||
| {% include_example java/org/apache/spark/examples/ml/JavaApproxSimilarityJoinExample.java %} | ||
| </div> | ||
| </div> | ||
|
|
||
## Approximate Nearest Neighbor Search

Approximate nearest neighbor search takes a dataset and a key, and approximately returns a specified number of rows in the dataset that are closest to the key. The number of rows to return is defined by the user.

Approximate nearest neighbor search allows users to cache the transformed columns when necessary: if `outputCol` is missing, the method transforms the data on the fly; if `outputCol` exists, it uses that column directly.

There are two methods of approximate nearest neighbor search implemented in `spark.ml`:
* Single-probing search: Only the hash bucket(s) where the key is hashed are searched. This method is time-efficient but might return fewer than k rows.
* Multi-probing search: All nearby hash buckets are searched, so that exactly k rows are returned whenever possible, at the cost of more running time.

The example code below shows the difference between single and multi probing.
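In short (excerpted from the Java example in this patch; `model`, `dataFrame`, and `key2` are defined there, and the fourth argument toggles single probing):

```java
// Single probing (true) may return fewer than k rows
model.approxNearestNeighbors(dataFrame, key2, 3, true, "distCol").show();
// Multi probing (false) returns exactly k rows whenever possible
model.approxNearestNeighbors(dataFrame, key2, 3, false, "distCol").show();
```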
| <div class="codetabs"> | ||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| {% include_example scala/org/apache/spark/examples/ml/ApproxNearestNeighborExample.scala %} | ||
| </div> | ||
|
|
||
| <div data-lang="java" markdown="1"> | ||
|
|
||
| {% include_example java/org/apache/spark/examples/ml/JavaApproxNearestNeighborExample.java %} | ||
| </div> | ||
| </div> | ||
| @@ -0,0 +1,85 @@ | ||
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import org.apache.spark.sql.SparkSession;

// $example on$
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.MinHash;
import org.apache.spark.ml.feature.MinHashModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$

public class JavaApproxNearestNeighborExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaApproxNearestNeighborExample")
      .getOrCreate();

    // $example on$
    List<Row> data = Arrays.asList(
      RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
      RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
      RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
    );

    StructType schema = new StructType(new StructField[]{
      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("keys", new VectorUDT(), false, Metadata.empty())
    });
    Dataset<Row> dataFrame = spark.createDataFrame(data, schema);

    MinHash mh = new MinHash()
      .setOutputDim(5)
      .setInputCol("keys")
      .setOutputCol("values");
    // Note: the number of values must match the number of indices in a sparse vector
    Vector key1 = Vectors.sparse(6, new int[]{1, 3}, new double[]{1.0, 1.0});
    Vector key2 = Vectors.sparse(6, new int[]{5}, new double[]{1.0});
    MinHashModel model = mh.fit(dataFrame);
    model.approxNearestNeighbors(dataFrame, key1, 2).show();

    System.out.println("Difference between single probing and multi probing:");

    System.out.println("Single probing sometimes returns fewer than k rows");
    model.approxNearestNeighbors(dataFrame, key2, 3, true, "distCol").show();

    System.out.println("Multi probing returns exactly k rows whenever possible");
    model.approxNearestNeighbors(dataFrame, key2, 3, false, "distCol").show();

    System.out.println("Multi probing returns the whole dataset when there are not enough rows");
    model.approxNearestNeighbors(dataFrame, key2, 4, false, "distCol").show();
    // $example off$

    spark.stop();
  }
}
| @@ -0,0 +1,85 @@ | ||
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import org.apache.spark.sql.SparkSession;

// $example on$
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.MinHash;
import org.apache.spark.ml.feature.MinHashModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$

public class JavaApproxSimilarityJoinExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
| .appName("JavaApproxNearestNeighborExample") | ||
      .getOrCreate();

    // $example on$
    List<Row> dataA = Arrays.asList(
      RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
      RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
      RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
    );

    List<Row> dataB = Arrays.asList(
      RowFactory.create(0, Vectors.sparse(6, new int[]{1, 3, 5}, new double[]{1.0, 1.0, 1.0})),
      RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 5}, new double[]{1.0, 1.0, 1.0})),
      RowFactory.create(2, Vectors.sparse(6, new int[]{1, 2, 4}, new double[]{1.0, 1.0, 1.0}))
    );

    StructType schema = new StructType(new StructField[]{
      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("keys", new VectorUDT(), false, Metadata.empty())
    });
    Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
    Dataset<Row> dfB = spark.createDataFrame(dataB, schema);

    MinHash mh = new MinHash()
      .setOutputDim(5)
      .setInputCol("keys")
      .setOutputCol("values");

    MinHashModel model = mh.fit(dfA);
    model.approxSimilarityJoin(dfA, dfB, 0.6).show();

    // Cache the transformed columns
    Dataset<Row> transformedA = model.transform(dfA);
    Dataset<Row> transformedB = model.transform(dfB);
    model.approxSimilarityJoin(transformedA, transformedB, 0.6).show();

    // Self join
    model.approxSimilarityJoin(dfA, dfA, 0.6).filter("datasetA.id < datasetB.id").show();
    // $example off$

    spark.stop();
  }
}
| @@ -0,0 +1,80 @@ | ||
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import org.apache.spark.sql.SparkSession;

// $example on$
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.MinHash;
import org.apache.spark.ml.feature.MinHashModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$

public class JavaLSHTransformationExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaLSHTransformationExample")
      .getOrCreate();

    // $example on$
    List<Row> data = Arrays.asList(
      RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
      RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
      RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
    );

    StructType schema = new StructType(new StructField[]{
      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("keys", new VectorUDT(), false, Metadata.empty())
    });
    Dataset<Row> dataFrame = spark.createDataFrame(data, schema);

    // Single LSH hashing
    MinHash mhSingleHash = new MinHash()
      .setOutputDim(1)
      .setInputCol("keys")
      .setOutputCol("values");
    MinHashModel modelSingleHash = mhSingleHash.fit(dataFrame);
    // Feature transformation: add a new column for a hash value
    modelSingleHash.transform(dataFrame).show();
    // Use more than one hash function
    MinHash mh = new MinHash()
      .setOutputDim(5)
      .setInputCol("keys")
      .setOutputCol("values");
    MinHashModel model = mh.fit(dataFrame);
    // Feature transformation: add a new column for multiple hash values
    model.transform(dataFrame).show();
    // $example off$

    spark.stop();
  }
}
Despite the opening sentence of the wikipedia article, I wouldn't class LSH as a dimensionality reduction technique? It's a set of hashing techniques where the hash preserves some properties. Maybe it's just my taste. But the rest of the text talks about the output as hash values.
What does "machine-learned ranking" refer to here? as this isn't a ranking technique per se.
I think this is missing a broad summary statement that indicates why LSH is even of interest: it provides a hash function where hashed values are in some sense close when the input values are close according to some metric. And then the variations below plug in different definitions of "close" and "input".
Rephrased. PTAL