131 changes: 131 additions & 0 deletions docs/ml-features.md
@@ -1396,3 +1396,134 @@ for more details on the API.
{% include_example python/ml/chisq_selector_example.py %}
</div>
</div>

# Locality Sensitive Hashing
[Locality Sensitive Hashing (LSH)](https://en.wikipedia.org/wiki/Locality-sensitive_hashing) is a class of dimension-reduction hash families, which can be used for both feature transformation and machine-learned ranking. Each distance metric has its own LSH family class in `spark.ml`, which can transform feature columns into hash values as new columns. Besides feature transformation, `spark.ml` also implements approximate nearest neighbor and approximate similarity join algorithms using LSH.
**Member:**
Despite the opening sentence of the wikipedia article, I wouldn't class LSH as a dimensionality reduction technique? It's a set of hashing techniques where the hash preserves some properties. Maybe it's just my taste. But the rest of the text talks about the output as hash values.

What does "machine-learned ranking" refer to here? as this isn't a ranking technique per se.

I think this is missing a broad summary statement that indicates why LSH is even of interest: it provides a hash function where hashed values are in some sense close when the input values are close according to some metric. And then the variations below plug in different definitions of "close" and "input".

**Contributor (author):**
Rephrased. PTAL


In this section, we call a pair of input features a false positive if the two features are hashed into the same hash bucket but are far apart in distance, and a false negative if the two features are close in distance but are not hashed into the same bucket.
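
To make the property concrete (an editorial aside, not part of the `spark.ml` wording): a family of hash functions is locality-sensitive for a distance `d` if, for a randomly drawn `h`, close inputs collide with noticeably higher probability than distant ones, e.g.
`\[
\Pr[h(\mathbf{x}) = h(\mathbf{y})] \geq p_1 \text{ if } d(\mathbf{x}, \mathbf{y}) \leq r_1, \qquad
\Pr[h(\mathbf{x}) = h(\mathbf{y})] \leq p_2 \text{ if } d(\mathbf{x}, \mathbf{y}) \geq r_2
\]`
for some thresholds r1 < r2 and probabilities p1 > p2. The families below plug in different definitions of distance and input.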

## Random Projection for Euclidean Distance
**Note:** This is different from the [Random Projection for cosine distance](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Random_projection).
**Member:**
Just an aside, but this random projection + cosine distance technique is the main thing I think of when I think of "LSH". Is that not implemented?


[Random Projection](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions) is the LSH family in `spark.ml` for Euclidean distance. The Euclidean distance is defined as follows:
**Contributor:**
Perhaps we can say something like "also referred to as p-stable distributions" or similar?

**Contributor (author):**
My understanding is that the 2-stable distribution is how we choose the random vector, not the name of the metric space or of the hash function.

See: https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions

`\[
d(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_i (x_i - y_i)^2}
\]`
Its LSH family projects feature vectors onto a random unit vector and divides the projected results into hash buckets:
`\[
h(\mathbf{x}) = \lfloor \frac{\mathbf{x} \cdot \mathbf{v}}{r} \rfloor
\]`
where `v` is a normalized random unit vector and `r` is the user-defined bucket length.
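
For intuition, here is a minimal sketch of this hash function in plain Scala (not the `spark.ml` implementation; the unit vector `v` and bucket length `r` below are illustrative values):

{% highlight scala %}
// Project x onto a random unit vector v and floor the result into buckets of width r.
def hash(x: Array[Double], v: Array[Double], r: Double): Long = {
  val projection = x.zip(v).map { case (xi, vi) => xi * vi }.sum
  math.floor(projection / r).toLong
}

val v = Array(0.6, 0.8)       // a normalized random unit vector
val r = 2.0                   // user-defined bucket length

hash(Array(1.0, 3.0), v, r)   // projection 3.0 -> bucket 1
hash(Array(1.2, 3.1), v, r)   // projection 3.2 -> bucket 1 (nearby points collide)
hash(Array(9.0, 0.0), v, r)   // projection 5.4 -> bucket 2
{% endhighlight %}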

The input features in Euclidean space are represented as vectors. Both sparse and dense vectors are supported.
**Contributor:**
This is a little unclear - should we say "RandomProjection accepts arbitrary vectors as input features, and supports both sparse and dense vectors".

**Contributor (author):**
Done.


The bucket length can be used to trade off the performance of random projection. A larger bucket length lowers the false negative rate but usually increases the running time and the false positive rate.
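
As a rough illustration of this trade-off (reusing the projection values from the sketch above; the numbers are illustrative, not from `spark.ml`):

{% highlight scala %}
// With r = 2.0, projections 3.0 and 5.4 land in different buckets (1 vs 2).
// With r = 8.0, both land in bucket 0: more pairs collide, so fewer false
// negatives, but more false positives and larger candidate sets to scan.
math.floor(3.0 / 2.0)   // 1.0
math.floor(5.4 / 2.0)   // 2.0
math.floor(3.0 / 8.0)   // 0.0
math.floor(5.4 / 8.0)   // 0.0
{% endhighlight %}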
**Member:**
Isn't the point here that near vectors end up in nearby buckets? I feel like this is skipping the intuition of why you would care about this technique or when you'd use it. Not like this needs a whole paragraph, just a few sentences of pointers?

**Contributor (author):**
Fixed. PTAL


<div class="codetabs">
<div data-lang="scala" markdown="1">

Refer to the [RandomProjection Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RandomProjection)
**Contributor:**
This Scaladoc link should be for BucketedRandomProjection now

**Contributor (author):**
Done.

for more details on the API.

{% include_example scala/org/apache/spark/examples/ml/RandomProjectionExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [RandomProjection Java docs](api/java/org/apache/spark/ml/feature/RandomProjection.html)
**Contributor:**
Same here

**Contributor (author):**
Done.

for more details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaRandomProjectionExample.java %}
</div>
</div>

## MinHash for Jaccard Distance
**Member:**
Similarity, not distance, right? it's higher when they overlap more.

**Contributor (author):**
Jaccard Distance is just 1 - JaccardSimilarity(https://en.wikipedia.org/wiki/Jaccard_index)

There are 2 reasons we use Jaccard Distance instead of similarity:
(1) It's cleaner to map each distance metric to their LSH family by the definition of LSH.
(2) In approxNearestNeighbor and approxSimilarityJoin, the returned dataset has a distCol showing the distance values.

[MinHash](https://en.wikipedia.org/wiki/MinHash) is the LSH family in `spark.ml` for Jaccard distance, where input features are sets of natural numbers. The Jaccard distance of two sets is defined by the cardinality of their intersection and union:
`\[
d(\mathbf{A}, \mathbf{B}) = 1 - \frac{|\mathbf{A} \cap \mathbf{B}|}{|\mathbf{A} \cup \mathbf{B}|}
\]`
As its LSH family, MinHash applies a random [perfect hash function](https://en.wikipedia.org/wiki/Perfect_hash_function) `g` to each element in the set and takes the minimum of all hashed values:
`\[
h(\mathbf{A}) = \min_{a \in \mathbf{A}}(g(a))
\]`

Input sets for MinHash are represented as vectors whose dimension equals the total number of elements in the space. Each dimension of the vector represents the status of an element: a zero value means the element is not in the set; a non-zero value means the set contains the corresponding element. For example, `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])` means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
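
A minimal sketch of these definitions in plain Scala (not the `spark.ml` implementation; the hash `g(a) = (13a + 1) mod 10` is an illustrative choice):

{% highlight scala %}
// The non-zero indices of the vector form the set.
val setA = Set(2, 3, 5)   // e.g. the indices of Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])
val setB = Set(3, 5, 7)

// Jaccard distance: 1 - |A ∩ B| / |A ∪ B| = 1 - 2/4 = 0.5
val jaccardDistance =
  1.0 - (setA intersect setB).size.toDouble / (setA union setB).size

// MinHash: apply g to every element of the set and take the minimum.
def g(a: Int): Int = (13 * a + 1) % 10
val hashA = setA.map(g).min   // min(7, 0, 6) = 0
val hashB = setB.map(g).min   // min(0, 6, 2) = 0
{% endhighlight %}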
**Member:**
Same sort of comment: isn't the intuition that MinHash approximates the Jaccard similarity without actually computing it completely?

This may again reduce to my different understanding of the technique, but MinHash isn't a hashing technique per se. it relies on a given family of hash functions to approximate set similarity. I'm finding it a little hard to view it as an 'LSH family'

**Contributor (author):**
Not sure how MinHash approximates the Jaccard similarity? It is true that Pr(min(h(A)) = min(h(B))) is equal to the Jaccard similarity when h is picked from a universal hash family. But I think we are not computing Pr(min(h(A)) = min(h(B))) in MinHash; we only use this property to construct an LSH function.

**Contributor:**
You should mention this property, because it is not very intuitive and forms the basis of using MinHash to approximate Jaccard distance.

**Contributor:**
Perhaps we can clarify here - "The input sets for MinHash are represented as binary vectors, where the vector indices represent the elements themselves and the non-zero values in the vector represent the presence of that element in the set. While both dense and sparse vectors are supported, typically sparse vectors are recommended for efficiency. For example, ..."

**Contributor:**
Do we require (check) in MinHash that the input vectors are binary? Or do we just treat any non-zero value as 1 effectively? Maybe mention it whichever it is.

**Contributor (author):**
Done.


**Note:** Empty sets cannot be transformed by MinHash, which means any input vector must have at least 1 non-zero entry.
**Contributor:**
"... non-zero entry" perhaps

**Contributor (author):**
Done.


<div class="codetabs">
<div data-lang="scala" markdown="1">

Refer to the [MinHash Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinHash)
**Contributor:**
Should be updated to MinHashLSH?

**Contributor (author):**
Done.

for more details on the API.

{% include_example scala/org/apache/spark/examples/ml/MinHashExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [MinHash Java docs](api/java/org/apache/spark/ml/feature/MinHash.html)
for more details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaMinHashExample.java %}
</div>
</div>

## Feature Transformation
**Contributor:**
We don't really mention anything here about the amplification approach taken. Perhaps we can mention that we hash into multiple hash tables, which are used for OR-amplification in similarity join and ANN. We can mention what the impact is on accuracy / performance of numHashTables.

We can also mention that AND-amplification will be added in future.

**Contributor (author):**
Both Done.

**Contributor:**
It would also be good to mention that the transformed dataset can be cached, since transform can be expensive. We can either mention it here, or perhaps mention it (twice) in the join and ANN sections below.

**Contributor (author):**
This doc is in L1509 and L1516

Feature transformation is the base functionality that adds the hash results as a new column. Users can specify the input and output column names by setting `inputCol` and `outputCol`. In `spark.ml`, all LSH families can also use multiple LSH hash functions; the number of hash functions can be set via `outputDim` in any LSH family.
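
A minimal Scala sketch of these parameters, mirroring the Java examples included in this PR (it assumes a DataFrame `df` with a vector column named "keys"):

{% highlight scala %}
import org.apache.spark.ml.feature.MinHash

val mh = new MinHash()
  .setOutputDim(5)         // number of hash functions
  .setInputCol("keys")     // feature column to hash
  .setOutputCol("values")  // new column that will hold the hash values

val model = mh.fit(df)
model.transform(df).show() // adds the "values" column with the hash results
{% endhighlight %}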

<div class="codetabs">
<div data-lang="scala" markdown="1">

{% include_example scala/org/apache/spark/examples/ml/LSHTransformationExample.scala %}
</div>

<div data-lang="java" markdown="1">

{% include_example java/org/apache/spark/examples/ml/JavaLSHTransformationExample.java %}
</div>
</div>

When multiple hash functions are picked, users can apply [amplification](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Amplification) to trade off between the false positive and false negative rates.
**Contributor:**
rate -> rate

**Contributor (author):**
Sorry, I did not get it?

**Contributor:**
I have not looked too much into the implementation of LSH, but this is a property of the queries, right? This should be moved into its own section along with some examples.

**Contributor (author):**
This section is removed since it will be fully implemented in https://issues.apache.org/jira/browse/SPARK-18450

* AND-amplification: Two input vectors are defined to be in the same bucket only if ALL of the hash values match. This will decrease the false positive rate but increase the false negative rate.
**Contributor:**
you need to add some extra spaces and new lines to make the list work. Try it out in a web-based markdown renderer if necessary

* OR-amplification: Two input vectors are defined to be in the same bucket as long as ANY one of the hash values matches. This will increase the false positive rate but decrease the false negative rate (see the sketch below).

**Contributor:**
Will self join produce duplicates? If so we should note that.

**Contributor (author):**
Done.

Approximate nearest neighbor and approximate similarity join use OR-amplification.
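
A minimal sketch of the difference in plain Scala (not the `spark.ml` implementation), comparing two hash signatures with one value per hash function:

{% highlight scala %}
// AND-amplification: same bucket only if ALL hash values match.
def sameBucketAnd(sigA: Seq[Long], sigB: Seq[Long]): Boolean =
  sigA.zip(sigB).forall { case (a, b) => a == b }

// OR-amplification: same bucket if ANY hash value matches.
def sameBucketOr(sigA: Seq[Long], sigB: Seq[Long]): Boolean =
  sigA.zip(sigB).exists { case (a, b) => a == b }

val sigA = Seq(4L, 7L, 2L)   // signatures from 3 hash functions
val sigB = Seq(4L, 9L, 2L)

sameBucketAnd(sigA, sigB)    // false: the second hash differs
sameBucketOr(sigA, sigB)     // true: the first and third hashes match
{% endhighlight %}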

## Approximate Similarity Join
Approximate similarity join takes two datasets and approximately returns pairs of rows whose distance is smaller than a user-defined threshold. Approximate similarity join supports both joining two different datasets and self-joining.
**Contributor:**
Row pairs of what? Does it return all columns or just the vector columns?

I think we need to be specific about "distance between two input vectors is smaller".

**Contributor (author):**
Add some description in L1501


Approximate similarity join allows users to cache the transformed columns when necessary: If the `outputCol` is missing, the method will transform the data; if the `outputCol` exists, it will use the `outputCol` directly.
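
A minimal Scala sketch of this pattern, mirroring the Java example included in this PR (it assumes DataFrames `dfA` and `dfB` with a vector column named "keys"):

{% highlight scala %}
import org.apache.spark.ml.feature.MinHash

val mh = new MinHash()
  .setOutputDim(5)
  .setInputCol("keys")
  .setOutputCol("values")
val model = mh.fit(dfA)

// Joining untransformed datasets: the hash column is computed on the fly.
model.approxSimilarityJoin(dfA, dfB, 0.6).show()

// Transform once and cache when the join will be run repeatedly.
val hashedA = model.transform(dfA).cache()
val hashedB = model.transform(dfB).cache()
model.approxSimilarityJoin(hashedA, hashedB, 0.6).show()
{% endhighlight %}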
**Contributor:**
I don't think it's totally clear what this means. Let's be more specific about the steps involved:

  1. transform the input dataset(s) to create the hash signature in LSH.outputCol.
  2. if an untransformed dataset is used as input, it will be transformed automatically

Because (1) is expensive, the transformed dataset can be cached if it will be re-used many times.

**Contributor (author):**
How about now?


<div class="codetabs">
<div data-lang="scala" markdown="1">

{% include_example scala/org/apache/spark/examples/ml/ApproxSimilarityJoinExample.scala %}
</div>

<div data-lang="java" markdown="1">

{% include_example java/org/apache/spark/examples/ml/JavaApproxSimilarityJoinExample.java %}
</div>
</div>

## Approximate Nearest Neighbor Search
Approximate nearest neighbor search takes a dataset and a key, and approximately returns a number of rows in the dataset that are closest to the key. The number of rows to return is defined by the user.
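
A minimal Scala sketch of the call, mirroring the Java example included in this PR (it assumes a fitted `model` and a DataFrame `df` with a vector column named "keys"):

{% highlight scala %}
import org.apache.spark.ml.linalg.Vectors

// Return (approximately) the 2 rows of df whose "keys" vectors are closest to key.
val key = Vectors.sparse(6, Array(1, 3), Array(1.0, 1.0))
model.approxNearestNeighbors(df, key, 2).show()
{% endhighlight %}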
**Contributor:**
Can simplify to "... returns a specified number of rows ..." and drop the last sentence.

Are we supporting arbitrary keys? I don't think so, so perhaps just call it "vector"?

**Contributor (author):**
Done.


Approximate nearest neighbor search allows users to cache the transformed columns when necessary: If the `outputCol` is missing, the method will transform the data; if the `outputCol` exists, it will use the `outputCol` directly.
**Contributor:**
Same comment as above for similarity join applies here.


There are two methods of approximate nearest neighbor search implemented in `spark.ml`:
* Single probing search: Only the hash bucket(s) where the key is hashed are searched. This method is time efficient but might return less than k rows.
* Multi probing search: All nearby hash buckets are searched to ensure exactly k rows are returned when possible, but this method can take more time.
The example code will show the difference between single and multi probing.
<div class="codetabs">
<div data-lang="scala" markdown="1">

{% include_example scala/org/apache/spark/examples/ml/ApproxNearestNeighborExample.scala %}
</div>

<div data-lang="java" markdown="1">

{% include_example java/org/apache/spark/examples/ml/JavaApproxNearestNeighborExample.java %}
</div>
</div>
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;

import org.apache.spark.sql.SparkSession;

// $example on$
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.MinHash;
import org.apache.spark.ml.feature.MinHashModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$

public class JavaApproxNearestNeighborExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("JavaApproxNearestNeighborExample")
.getOrCreate();

// $example on$
List<Row> data = Arrays.asList(
RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
);

StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("keys", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> dataFrame = spark.createDataFrame(data, schema);

MinHash mh = new MinHash()
.setOutputDim(5)
.setInputCol("keys")
.setOutputCol("values");

Vector key1 = Vectors.sparse(6, new int[]{1, 3}, new double[]{1.0, 1.0});
Vector key2 = Vectors.sparse(6, new int[]{5}, new double[]{1.0});

MinHashModel model = mh.fit(dataFrame);
model.approxNearestNeighbors(dataFrame, key1, 2).show();

System.out.println("Difference between single probing and multi probing:");

System.out.println("Single probing sometimes returns less than k rows");
model.approxNearestNeighbors(dataFrame, key2, 3, true, "distCol").show();

System.out.println("Multi probing returns exact k rows whenever possible");
model.approxNearestNeighbors(dataFrame, key2, 3, false, "distCol").show();

System.out.println("Multi probing returns the whole dataset when there are not enough rows");
model.approxNearestNeighbors(dataFrame, key2, 4, false, "distCol").show();
// $example off$

spark.stop();
}
}
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;

import org.apache.spark.sql.SparkSession;

// $example on$
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.MinHash;
import org.apache.spark.ml.feature.MinHashModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$

public class JavaApproxSimilarityJoinExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("JavaApproxNearestNeighborExample")
.getOrCreate();

// $example on$
List<Row> dataA = Arrays.asList(
RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
);

List<Row> dataB = Arrays.asList(
RowFactory.create(0, Vectors.sparse(6, new int[]{1, 3, 5}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 5}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(2, Vectors.sparse(6, new int[]{1, 2, 4}, new double[]{1.0, 1.0, 1.0}))
);

StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("keys", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
Dataset<Row> dfB = spark.createDataFrame(dataB, schema);

MinHash mh = new MinHash()
.setOutputDim(5)
.setInputCol("keys")
.setOutputCol("values");

MinHashModel model = mh.fit(dfA);
model.approxSimilarityJoin(dfA, dfB, 0.6).show();

// Cache the transformed columns
Dataset<Row> transformedA = model.transform(dfA);
Dataset<Row> transformedB = model.transform(dfB);
model.approxSimilarityJoin(transformedA, transformedB, 0.6).show();

// Self Join
model.approxSimilarityJoin(dfA, dfA, 0.6).filter("datasetA.id < datasetB.id").show();
// $example off$

spark.stop();
}
}
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;

import org.apache.spark.sql.SparkSession;

// $example on$
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.MinHash;
import org.apache.spark.ml.feature.MinHashModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$

public class JavaLSHTransformationExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("JavaLSHTransformationExample")
.getOrCreate();

// $example on$
List<Row> data = Arrays.asList(
RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
);

StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("keys", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> dataFrame = spark.createDataFrame(data, schema);

// Single LSH hashing
MinHash mhSingleHash = new MinHash()
.setOutputDim(1)
.setInputCol("keys")
.setOutputCol("values");
MinHashModel modelSingleHash = mhSingleHash.fit(dataFrame);
// Feature transformation: add a new column for a hash value
modelSingleHash.transform(dataFrame).show();

// Use more than 1 hash function
MinHash mh = new MinHash()
.setOutputDim(5)
.setInputCol("keys")
.setOutputCol("values");
MinHashModel model = mh.fit(dataFrame);
// Feature Transformation: add a new column for multiple hash values
model.transform(dataFrame).show();
// $example off$

spark.stop();
}
}