Skip to content

Commit db808a1

Browse files
committed
update JavaLR example
1 parent befa592 commit db808a1

1 file changed

Lines changed: 4 additions & 10 deletions

File tree

  • examples/src/main/java/org/apache/spark/mllib/examples

examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,17 @@
1717

1818
package org.apache.spark.mllib.examples;
1919

20+
import java.util.regex.Pattern;
2021

2122
import org.apache.spark.api.java.JavaRDD;
2223
import org.apache.spark.api.java.JavaSparkContext;
2324
import org.apache.spark.api.java.function.Function;
2425

2526
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
2627
import org.apache.spark.mllib.classification.LogisticRegressionModel;
28+
import org.apache.spark.mllib.linalg.Vectors;
2729
import org.apache.spark.mllib.regression.LabeledPoint;
2830

29-
import java.util.Arrays;
30-
import java.util.regex.Pattern;
31-
3231
/**
3332
* Logistic regression based classification using ML Lib.
3433
*/
@@ -47,14 +46,10 @@ public LabeledPoint call(String line) {
4746
for (int i = 0; i < tok.length; ++i) {
4847
x[i] = Double.parseDouble(tok[i]);
4948
}
50-
return new LabeledPoint(y, x);
49+
return new LabeledPoint(y, Vectors.dense(x));
5150
}
5251
}
5352

54-
public static void printWeights(double[] a) {
55-
System.out.println(Arrays.toString(a));
56-
}
57-
5853
public static void main(String[] args) {
5954
if (args.length != 4) {
6055
System.err.println("Usage: JavaLR <master> <input_dir> <step_size> <niters>");
@@ -80,8 +75,7 @@ public static void main(String[] args) {
8075
LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(),
8176
iterations, stepSize);
8277

83-
System.out.print("Final w: ");
84-
printWeights(model.weights());
78+
System.out.print("Final w: " + model.weights());
8579

8680
System.exit(0);
8781
}

0 commit comments

Comments
 (0)