Skip to content

Commit 8e4665d

Browse files
committed
Added tree.py to the Python doctests. Fixed a bug caused by a missing categoricalFeaturesInfo argument.
1 parent b7b2922 commit 8e4665d

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

python/pyspark/mllib/tree.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ class DecisionTree(object):
88 88
It will probably be modified for Spark v1.2.
89 89
90 90
Example usage:
91-
>>> from numpy import array, ndarray
91+
>>> from numpy import array
92 92
>>> from pyspark.mllib.regression import LabeledPoint
93 93
>>> from pyspark.mllib.tree import DecisionTree
94 94
>>> from pyspark.mllib.linalg import SparseVector
@@ -99,8 +99,9 @@ class DecisionTree(object):
99 99
... LabeledPoint(1.0, [2.0]),
100 100
... LabeledPoint(1.0, [3.0])
101 101
... ]
102-
>>>
103-
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2)
102+
>>> categoricalFeaturesInfo = {} # no categorical features
103+
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
104+
... categoricalFeaturesInfo=categoricalFeaturesInfo)
104 105
>>> print(model)
105 106
DecisionTreeModel classifier
106 107
If (feature 0 <= 0.5)
@@ -119,7 +120,8 @@ class DecisionTree(object):
119 120
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
120 121
... ]
121 122
>>>
122-
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data))
123+
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
124+
... categoricalFeaturesInfo=categoricalFeaturesInfo)
123 125
>>> model.predict(array([0.0, 1.0])) == 1
124 126
True
125 127
>>> model.predict(array([0.0, 0.0])) == 0

python/run-tests

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ run_test "pyspark/mllib/random.py"
79 79
run_test "pyspark/mllib/recommendation.py"
80 80
run_test "pyspark/mllib/regression.py"
81 81
run_test "pyspark/mllib/tests.py"
82+
run_test "pyspark/mllib/tree.py"
82 83
run_test "pyspark/mllib/util.py"
83 84

84 85
if [[ $FAILED == 0 ]]; then

0 commit comments

Comments
 (0)