Skip to content
Closed
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
68cebd0
ML Classification done
techaddict May 13, 2016
e8bae89
ml clustering done
techaddict May 13, 2016
afe620a
ML Feature done
techaddict May 13, 2016
f3a2244
ML Feature remove unused imports
techaddict May 13, 2016
9e41936
ML Param Done - Remove SparkSession, esc since not used
techaddict May 13, 2016
b97da5b
ml regression done
techaddict May 13, 2016
a96fa3a
ml libsvm done
techaddict May 13, 2016
29a1194
ml tuning, util done
techaddict May 13, 2016
8ad34a0
mllib classification
techaddict May 13, 2016
940d564
mllib clustering
techaddict May 13, 2016
9d4d015
mllib evaluation and feature
techaddict May 13, 2016
e033253
mllib fpm
techaddict May 13, 2016
3aa61df
mllib random
techaddict May 13, 2016
ddf68da
mllib recommendation
techaddict May 13, 2016
90b048a
mllib regression
techaddict May 13, 2016
c3f166d
mllib tree
techaddict May 13, 2016
36ce8d2
fix javastyle
techaddict May 13, 2016
12ba028
add license to SharedSparkSession
techaddict May 13, 2016
39cd94d
merge master
techaddict May 18, 2016
f3fa1f5
fix import
techaddict May 18, 2016
e4117f3
Mark custom methods as protected and add override
techaddict May 19, 2016
874eddb
fix java lint errors
techaddict May 19, 2016
8e264ea
fix style
techaddict May 19, 2016
92ae70b
Merge branch 'master' into SPARK-15296
techaddict May 19, 2016
40edaad
remove customSetUp() and customTearDown()
techaddict May 20, 2016
65307c5
Merge branch 'master' into SPARK-15296
techaddict May 20, 2016
9b340fe
SparkSession appName should be class simple name
techaddict May 20, 2016
622e28d
Merge branch 'master' into SPARK-15296
techaddict May 20, 2016
56de3e4
Merge branch 'master' into SPARK-15296
techaddict May 20, 2016
c1ce08e
SharedSparkSession should implement Serializable
techaddict May 20, 2016
2bcfdd7
fix imports
techaddict May 20, 2016
9d7a89c
Merge branch 'master' into SPARK-15296
techaddict May 20, 2016
138818b
fix
techaddict May 20, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class JavaGaussianMixtureExample {

public static void main(String[] args) {

// Creates a SparkSession
// Creates a SparkSession
SparkSession spark = SparkSession
.builder()
.appName("JavaGaussianMixtureExample")
Expand Down
47 changes: 47 additions & 0 deletions mllib/src/test/java/org/apache/spark/SharedSparkSession.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark;

import java.io.IOException;

import org.junit.After;
import org.junit.Before;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

public abstract class SharedSparkSession {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we include implements Serializable to save some code?


public transient SparkSession spark;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: This should be protected.

public transient JavaSparkContext jsc;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.


@Before
public void setUp() throws IOException {
spark = SparkSession.builder()
.master("local")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some tests depend on the default number of partitions. So it would be better to say local[2] instead of local.

.appName("shared-spark-session")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could set the app name to the name of the test class, using getClass().getSimpleName().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jsc needs to be stopped too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All jsc.stop() does is stop the underlying SparkContext, which is already done when the SparkSession is stopped.

}

/**
 * Stops the session created by {@code setUp()} and clears the {@code spark} field.
 *
 * <p>The null check matters because JUnit 4 still invokes {@code @After} methods when an
 * {@code @Before} method throws. If {@code setUp()} fails before {@code spark} is assigned,
 * an unguarded {@code spark.stop()} would raise a {@code NullPointerException} on top of the
 * original setup failure.
 */
@After
public void tearDown() {
  try {
    if (spark != null) {
      spark.stop();
    }
  } finally {
    // Clear the reference even if stop() throws, so a stopped session is never reused.
    spark = null;
  }
}
}
27 changes: 7 additions & 20 deletions mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,34 @@

package org.apache.spark.ml;

import org.junit.After;
import org.junit.Before;
import java.io.IOException;

import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**
* Test Pipeline construction and fitting in Java.
*/
public class JavaPipelineSuite {
public class JavaPipelineSuite extends SharedSparkSession {

private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaPipelineSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
@Override
public void setUp() throws IOException {
super.setUp();
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
dataset = spark.createDataFrame(points, LabeledPoint.class);
}

@After
public void tearDown() {
spark.stop();
spark = null;
}

@Test
public void pipeline() {
StandardScaler scaler = new StandardScaler()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,16 @@
import java.util.HashMap;
import java.util.Map;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaDecisionTreeClassifierSuite implements Serializable {

private transient SparkSession spark;
private transient JavaSparkContext jsc;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaDecisionTreeClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}

@After
public void tearDown() {
spark.stop();
spark = null;
}
public class JavaDecisionTreeClassifierSuite extends SharedSparkSession implements Serializable {

@Test
public void runDT() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,16 @@
import java.util.HashMap;
import java.util.Map;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;


public class JavaGBTClassifierSuite implements Serializable {

private transient SparkSession spark;
private transient JavaSparkContext jsc;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaGBTClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}

@After
public void tearDown() {
spark.stop();
spark = null;
}
public class JavaGBTClassifierSuite extends SharedSparkSession implements Serializable {

@Test
public void runDT() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,37 @@

package org.apache.spark.ml.classification;

import java.io.IOException;
import java.io.Serializable;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaLogisticRegressionSuite implements Serializable {
public class JavaLogisticRegressionSuite extends SharedSparkSession implements Serializable {

private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;

private transient JavaRDD<LabeledPoint> datasetRDD;
private double eps = 1e-5;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());

@Override
public void setUp() throws IOException {
super.setUp();
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.createOrReplaceTempView("dataset");
}

@After
public void tearDown() {
spark.stop();
spark = null;
}

@Test
public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,17 @@
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaMultilayerPerceptronClassifierSuite implements Serializable {

private transient SparkSession spark;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
}

@After
public void tearDown() {
spark.stop();
spark = null;
}
public class JavaMultilayerPerceptronClassifierSuite
extends SharedSparkSession implements Serializable {

@Test
public void testMLPC() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,21 @@
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class JavaNaiveBayesSuite implements Serializable {

private transient SparkSession spark;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
}

@After
public void tearDown() {
spark.stop();
spark = null;
}
public class JavaNaiveBayesSuite extends SharedSparkSession implements Serializable {

public void validatePrediction(Dataset<Row> predictionAndLabels) {
for (Row r : predictionAndLabels.collectAsList()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,30 @@

package org.apache.spark.ml.classification;

import java.io.IOException;
import java.io.Serializable;
import java.util.List;

import scala.collection.JavaConverters;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.feature.LabeledPoint;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;

public class JavaOneVsRestSuite implements Serializable {
public class JavaOneVsRestSuite extends SharedSparkSession implements Serializable {

private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaLOneVsRestSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());

@Override
public void setUp() throws IOException {
super.setUp();
int nPoints = 3;

// The following coefficients and xMean/xVariance are computed from iris dataset with
Expand All @@ -68,12 +59,6 @@ public void setUp() {
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
}

@After
public void tearDown() {
spark.stop();
spark = null;
}

@Test
public void oneVsRestDefaultParams() {
OneVsRest ova = new OneVsRest();
Expand Down
Loading