@@ -88,7 +88,7 @@ class DecisionTree(object):
8888 It will probably be modified for Spark v1.2.
8989
9090 Example usage:
91- >>> from numpy import array, ndarray
91+ >>> from numpy import array
9292 >>> from pyspark.mllib.regression import LabeledPoint
9393 >>> from pyspark.mllib.tree import DecisionTree
9494 >>> from pyspark.mllib.linalg import SparseVector
@@ -99,8 +99,9 @@ class DecisionTree(object):
9999 ... LabeledPoint(1.0, [2.0]),
100100 ... LabeledPoint(1.0, [3.0])
101101 ... ]
102- >>>
103- >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2)
102+ >>> categoricalFeaturesInfo = {}  # empty map: no categorical features (all treated as continuous)
103+ >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
104+ ... categoricalFeaturesInfo=categoricalFeaturesInfo)
104105 >>> print(model)
105106 DecisionTreeModel classifier
106107 If (feature 0 <= 0.5)
@@ -119,7 +120,8 @@ class DecisionTree(object):
119120 ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
120121 ... ]
121122 >>>
122- >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data))
123+ >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
124+ ... categoricalFeaturesInfo=categoricalFeaturesInfo)
123125 >>> model.predict(array([0.0, 1.0])) == 1
124126 True
125127 >>> model.predict(array([0.0, 0.0])) == 0
0 commit comments