Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
68cebd0
ML Classification done
techaddict May 13, 2016
e8bae89
ml clustering done
techaddict May 13, 2016
afe620a
ML Feature done
techaddict May 13, 2016
f3a2244
ML Feature remove unused imports
techaddict May 13, 2016
9e41936
ML Param Done - Remove SparkSession, esc since not used
techaddict May 13, 2016
b97da5b
ml regression done
techaddict May 13, 2016
a96fa3a
ml libsvm done
techaddict May 13, 2016
29a1194
ml tuning, util done
techaddict May 13, 2016
8ad34a0
mllib classification
techaddict May 13, 2016
940d564
mllib clustering
techaddict May 13, 2016
9d4d015
mllib evaluation and feature
techaddict May 13, 2016
e033253
mllib fpm
techaddict May 13, 2016
3aa61df
mllib random
techaddict May 13, 2016
ddf68da
mllib recommendation
techaddict May 13, 2016
90b048a
mllib regression
techaddict May 13, 2016
c3f166d
mllib tree
techaddict May 13, 2016
36ce8d2
fix javastyle
techaddict May 13, 2016
12ba028
add license to SharedSparkSession
techaddict May 13, 2016
39cd94d
merge master
techaddict May 18, 2016
f3fa1f5
fix import
techaddict May 18, 2016
e4117f3
Mark custom methods as protected and add override
techaddict May 19, 2016
874eddb
fix java lint errors
techaddict May 19, 2016
8e264ea
fix style
techaddict May 19, 2016
92ae70b
Merge branch 'master' into SPARK-15296
techaddict May 19, 2016
40edaad
remove customSetUp() and customTearDown()
techaddict May 20, 2016
65307c5
Merge branch 'master' into SPARK-15296
techaddict May 20, 2016
9b340fe
SparkSession appName should be class simple name
techaddict May 20, 2016
622e28d
Merge branch 'master' into SPARK-15296
techaddict May 20, 2016
56de3e4
Merge branch 'master' into SPARK-15296
techaddict May 20, 2016
c1ce08e
SharedSparkSession should implement Serializable
techaddict May 20, 2016
2bcfdd7
fix imports
techaddict May 20, 2016
9d7a89c
Merge branch 'master' into SPARK-15296
techaddict May 20, 2016
138818b
fix
techaddict May 20, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class JavaGaussianMixtureExample {

public static void main(String[] args) {

// Creates a SparkSession
// Creates a SparkSession
SparkSession spark = SparkSession
.builder()
.appName("JavaGaussianMixtureExample")
Expand Down
48 changes: 48 additions & 0 deletions mllib/src/test/java/org/apache/spark/SharedSparkSession.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark;

import java.io.IOException;
import java.io.Serializable;

import org.junit.After;
import org.junit.Before;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

public abstract class SharedSparkSession implements Serializable {

protected transient SparkSession spark;
protected transient JavaSparkContext jsc;

@Before
public void setUp() throws IOException {
spark = SparkSession.builder()
.master("local[2]")
.appName(getClass().getSimpleName())
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jsc needs to be stopped too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All `jsc.stop()` does is stop the underlying SparkContext, which is already done by `spark.stop()` on the SparkSession.

}

/**
 * Stops the shared SparkSession after each test and releases the references
 * held by this fixture.
 *
 * The session is only stopped when it was actually created, so a failure in
 * setUp() before spark is assigned does not produce an additional
 * NullPointerException here. jsc is cleared too because it was built from
 * spark.sparkContext() in setUp() and should not be reused once the session
 * has been stopped.
 */
@After
public void tearDown() {
  try {
    if (spark != null) {
      spark.stop();
    }
  } finally {
    // Drop both handles even if stop() throws, so no test sees a stale session.
    spark = null;
    jsc = null;
  }
}
}
27 changes: 7 additions & 20 deletions mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,34 @@

package org.apache.spark.ml;

import org.junit.After;
import org.junit.Before;
import java.io.IOException;

import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**
* Test Pipeline construction and fitting in Java.
*/
public class JavaPipelineSuite {
public class JavaPipelineSuite extends SharedSparkSession {

private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaPipelineSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
@Override
public void setUp() throws IOException {
super.setUp();
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
dataset = spark.createDataFrame(points, LabeledPoint.class);
}

@After
public void tearDown() {
spark.stop();
spark = null;
}

@Test
public void pipeline() {
StandardScaler scaler = new StandardScaler()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,19 @@

package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaDecisionTreeClassifierSuite implements Serializable {

private transient SparkSession spark;
private transient JavaSparkContext jsc;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaDecisionTreeClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}

@After
public void tearDown() {
spark.stop();
spark = null;
}
public class JavaDecisionTreeClassifierSuite extends SharedSparkSession {

@Test
public void runDT() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,19 @@

package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;


public class JavaGBTClassifierSuite implements Serializable {

private transient SparkSession spark;
private transient JavaSparkContext jsc;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaGBTClassifierSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());
}

@After
public void tearDown() {
spark.stop();
spark = null;
}
public class JavaGBTClassifierSuite extends SharedSparkSession {

@Test
public void runDT() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,36 @@

package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.io.IOException;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaLogisticRegressionSuite implements Serializable {
public class JavaLogisticRegressionSuite extends SharedSparkSession {

private transient SparkSession spark;
private transient JavaSparkContext jsc;
private transient Dataset<Row> dataset;

private transient JavaRDD<LabeledPoint> datasetRDD;
private double eps = 1e-5;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
jsc = new JavaSparkContext(spark.sparkContext());

@Override
public void setUp() throws IOException {
super.setUp();
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.createOrReplaceTempView("dataset");
}

@After
public void tearDown() {
spark.stop();
spark = null;
}

@Test
public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,19 @@

package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaMultilayerPerceptronClassifierSuite implements Serializable {

private transient SparkSession spark;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
}

@After
public void tearDown() {
spark.stop();
spark = null;
}
public class JavaMultilayerPerceptronClassifierSuite extends SharedSparkSession {

@Test
public void testMLPC() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,24 @@

package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class JavaNaiveBayesSuite implements Serializable {

private transient SparkSession spark;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local")
.appName("JavaLogisticRegressionSuite")
.getOrCreate();
}

@After
public void tearDown() {
spark.stop();
spark = null;
}
public class JavaNaiveBayesSuite extends SharedSparkSession {

public void validatePrediction(Dataset<Row> predictionAndLabels) {
for (Row r : predictionAndLabels.collectAsList()) {
Expand Down
Loading