Skip to content

Commit 01cf649

Browse files
techaddictmengxr
authored andcommitted
[SPARK-15296][MLLIB] Refactor All Java Tests that use SparkSession
## What changes were proposed in this pull request? Refactor All Java Tests that use SparkSession, to extend SharedSparkSesion ## How was this patch tested? Existing Tests Author: Sandeep Singh <[email protected]> Closes apache#13101 from techaddict/SPARK-15296.
1 parent 16ba71a commit 01cf649

File tree

59 files changed

+207
-1148
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+207
-1148
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public class JavaGaussianMixtureExample {
3737

3838
public static void main(String[] args) {
3939

40-
// Creates a SparkSession
40+
// Creates a SparkSession
4141
SparkSession spark = SparkSession
4242
.builder()
4343
.appName("JavaGaussianMixtureExample")
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark;
19+
20+
import java.io.IOException;
21+
import java.io.Serializable;
22+
23+
import org.junit.After;
24+
import org.junit.Before;
25+
26+
import org.apache.spark.api.java.JavaSparkContext;
27+
import org.apache.spark.sql.SparkSession;
28+
29+
public abstract class SharedSparkSession implements Serializable {
30+
31+
protected transient SparkSession spark;
32+
protected transient JavaSparkContext jsc;
33+
34+
@Before
35+
public void setUp() throws IOException {
36+
spark = SparkSession.builder()
37+
.master("local[2]")
38+
.appName(getClass().getSimpleName())
39+
.getOrCreate();
40+
jsc = new JavaSparkContext(spark.sparkContext());
41+
}
42+
43+
@After
44+
public void tearDown() {
45+
spark.stop();
46+
spark = null;
47+
}
48+
}

mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,47 +17,34 @@
1717

1818
package org.apache.spark.ml;
1919

20-
import org.junit.After;
21-
import org.junit.Before;
20+
import java.io.IOException;
21+
2222
import org.junit.Test;
2323

24+
import org.apache.spark.SharedSparkSession;
2425
import org.apache.spark.api.java.JavaRDD;
25-
import org.apache.spark.api.java.JavaSparkContext;
2626
import org.apache.spark.ml.classification.LogisticRegression;
2727
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
2828
import org.apache.spark.ml.feature.LabeledPoint;
2929
import org.apache.spark.ml.feature.StandardScaler;
3030
import org.apache.spark.sql.Dataset;
3131
import org.apache.spark.sql.Row;
32-
import org.apache.spark.sql.SparkSession;
3332

3433
/**
3534
* Test Pipeline construction and fitting in Java.
3635
*/
37-
public class JavaPipelineSuite {
36+
public class JavaPipelineSuite extends SharedSparkSession {
3837

39-
private transient SparkSession spark;
40-
private transient JavaSparkContext jsc;
4138
private transient Dataset<Row> dataset;
4239

43-
@Before
44-
public void setUp() {
45-
spark = SparkSession.builder()
46-
.master("local")
47-
.appName("JavaPipelineSuite")
48-
.getOrCreate();
49-
jsc = new JavaSparkContext(spark.sparkContext());
40+
@Override
41+
public void setUp() throws IOException {
42+
super.setUp();
5043
JavaRDD<LabeledPoint> points =
5144
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
5245
dataset = spark.createDataFrame(points, LabeledPoint.class);
5346
}
5447

55-
@After
56-
public void tearDown() {
57-
spark.stop();
58-
spark = null;
59-
}
60-
6148
@Test
6249
public void pipeline() {
6350
StandardScaler scaler = new StandardScaler()

mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,19 @@
1717

1818
package org.apache.spark.ml.classification;
1919

20-
import java.io.Serializable;
2120
import java.util.HashMap;
2221
import java.util.Map;
2322

24-
import org.junit.After;
25-
import org.junit.Before;
2623
import org.junit.Test;
2724

25+
import org.apache.spark.SharedSparkSession;
2826
import org.apache.spark.api.java.JavaRDD;
29-
import org.apache.spark.api.java.JavaSparkContext;
30-
import org.apache.spark.ml.classification.LogisticRegressionSuite;
3127
import org.apache.spark.ml.feature.LabeledPoint;
3228
import org.apache.spark.ml.tree.impl.TreeTests;
3329
import org.apache.spark.sql.Dataset;
3430
import org.apache.spark.sql.Row;
35-
import org.apache.spark.sql.SparkSession;
3631

37-
public class JavaDecisionTreeClassifierSuite implements Serializable {
38-
39-
private transient SparkSession spark;
40-
private transient JavaSparkContext jsc;
41-
42-
@Before
43-
public void setUp() {
44-
spark = SparkSession.builder()
45-
.master("local")
46-
.appName("JavaDecisionTreeClassifierSuite")
47-
.getOrCreate();
48-
jsc = new JavaSparkContext(spark.sparkContext());
49-
}
50-
51-
@After
52-
public void tearDown() {
53-
spark.stop();
54-
spark = null;
55-
}
32+
public class JavaDecisionTreeClassifierSuite extends SharedSparkSession {
5633

5734
@Test
5835
public void runDT() {

mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,19 @@
1717

1818
package org.apache.spark.ml.classification;
1919

20-
import java.io.Serializable;
2120
import java.util.HashMap;
2221
import java.util.Map;
2322

24-
import org.junit.After;
25-
import org.junit.Before;
2623
import org.junit.Test;
2724

25+
import org.apache.spark.SharedSparkSession;
2826
import org.apache.spark.api.java.JavaRDD;
29-
import org.apache.spark.api.java.JavaSparkContext;
30-
import org.apache.spark.ml.classification.LogisticRegressionSuite;
3127
import org.apache.spark.ml.feature.LabeledPoint;
3228
import org.apache.spark.ml.tree.impl.TreeTests;
3329
import org.apache.spark.sql.Dataset;
3430
import org.apache.spark.sql.Row;
35-
import org.apache.spark.sql.SparkSession;
3631

37-
38-
public class JavaGBTClassifierSuite implements Serializable {
39-
40-
private transient SparkSession spark;
41-
private transient JavaSparkContext jsc;
42-
43-
@Before
44-
public void setUp() {
45-
spark = SparkSession.builder()
46-
.master("local")
47-
.appName("JavaGBTClassifierSuite")
48-
.getOrCreate();
49-
jsc = new JavaSparkContext(spark.sparkContext());
50-
}
51-
52-
@After
53-
public void tearDown() {
54-
spark.stop();
55-
spark = null;
56-
}
32+
public class JavaGBTClassifierSuite extends SharedSparkSession {
5733

5834
@Test
5935
public void runDT() {

mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,52 +17,36 @@
1717

1818
package org.apache.spark.ml.classification;
1919

20-
import java.io.Serializable;
20+
import java.io.IOException;
2121
import java.util.List;
2222

23-
import org.junit.After;
2423
import org.junit.Assert;
25-
import org.junit.Before;
2624
import org.junit.Test;
2725

26+
import org.apache.spark.SharedSparkSession;
2827
import org.apache.spark.api.java.JavaRDD;
29-
import org.apache.spark.api.java.JavaSparkContext;
3028
import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
3129
import org.apache.spark.ml.feature.LabeledPoint;
3230
import org.apache.spark.ml.linalg.Vector;
3331
import org.apache.spark.sql.Dataset;
3432
import org.apache.spark.sql.Row;
35-
import org.apache.spark.sql.SparkSession;
3633

37-
public class JavaLogisticRegressionSuite implements Serializable {
34+
public class JavaLogisticRegressionSuite extends SharedSparkSession {
3835

39-
private transient SparkSession spark;
40-
private transient JavaSparkContext jsc;
4136
private transient Dataset<Row> dataset;
4237

4338
private transient JavaRDD<LabeledPoint> datasetRDD;
4439
private double eps = 1e-5;
4540

46-
@Before
47-
public void setUp() {
48-
spark = SparkSession.builder()
49-
.master("local")
50-
.appName("JavaLogisticRegressionSuite")
51-
.getOrCreate();
52-
jsc = new JavaSparkContext(spark.sparkContext());
53-
41+
@Override
42+
public void setUp() throws IOException {
43+
super.setUp();
5444
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
5545
datasetRDD = jsc.parallelize(points, 2);
5646
dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
5747
dataset.createOrReplaceTempView("dataset");
5848
}
5949

60-
@After
61-
public void tearDown() {
62-
spark.stop();
63-
spark = null;
64-
}
65-
6650
@Test
6751
public void logisticRegressionDefaultParams() {
6852
LogisticRegression lr = new LogisticRegression();

mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,19 @@
1717

1818
package org.apache.spark.ml.classification;
1919

20-
import java.io.Serializable;
2120
import java.util.Arrays;
2221
import java.util.List;
2322

24-
import org.junit.After;
2523
import org.junit.Assert;
26-
import org.junit.Before;
2724
import org.junit.Test;
2825

26+
import org.apache.spark.SharedSparkSession;
2927
import org.apache.spark.ml.feature.LabeledPoint;
3028
import org.apache.spark.ml.linalg.Vectors;
3129
import org.apache.spark.sql.Dataset;
3230
import org.apache.spark.sql.Row;
33-
import org.apache.spark.sql.SparkSession;
3431

35-
public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
36-
37-
private transient SparkSession spark;
38-
39-
@Before
40-
public void setUp() {
41-
spark = SparkSession.builder()
42-
.master("local")
43-
.appName("JavaLogisticRegressionSuite")
44-
.getOrCreate();
45-
}
46-
47-
@After
48-
public void tearDown() {
49-
spark.stop();
50-
spark = null;
51-
}
32+
public class JavaMultilayerPerceptronClassifierSuite extends SharedSparkSession {
5233

5334
@Test
5435
public void testMLPC() {

mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,24 @@
1717

1818
package org.apache.spark.ml.classification;
1919

20-
import java.io.Serializable;
2120
import java.util.Arrays;
2221
import java.util.List;
2322

24-
import org.junit.After;
25-
import org.junit.Before;
2623
import org.junit.Test;
2724
import static org.junit.Assert.assertEquals;
2825

26+
import org.apache.spark.SharedSparkSession;
2927
import org.apache.spark.ml.linalg.VectorUDT;
3028
import org.apache.spark.ml.linalg.Vectors;
3129
import org.apache.spark.sql.Dataset;
3230
import org.apache.spark.sql.Row;
3331
import org.apache.spark.sql.RowFactory;
34-
import org.apache.spark.sql.SparkSession;
3532
import org.apache.spark.sql.types.DataTypes;
3633
import org.apache.spark.sql.types.Metadata;
3734
import org.apache.spark.sql.types.StructField;
3835
import org.apache.spark.sql.types.StructType;
3936

40-
public class JavaNaiveBayesSuite implements Serializable {
41-
42-
private transient SparkSession spark;
43-
44-
@Before
45-
public void setUp() {
46-
spark = SparkSession.builder()
47-
.master("local")
48-
.appName("JavaLogisticRegressionSuite")
49-
.getOrCreate();
50-
}
51-
52-
@After
53-
public void tearDown() {
54-
spark.stop();
55-
spark = null;
56-
}
37+
public class JavaNaiveBayesSuite extends SharedSparkSession {
5738

5839
public void validatePrediction(Dataset<Row> predictionAndLabels) {
5940
for (Row r : predictionAndLabels.collectAsList()) {

0 commit comments

Comments
 (0)