Skip to content

Commit 5952bdb

Browse files
vectorijk authored and mengxr committed
[SPARK-10688] [ML] [PYSPARK] Python API for AFTSurvivalRegression
Implement Python API for AFTSurvivalRegression Author: vectorijk <jiangkai@gmail.com> Closes #8926 from vectorijk/spark-10688.
1 parent e978360 commit 5952bdb

1 file changed

Lines changed: 169 additions & 2 deletions

File tree

python/pyspark/ml/regression.py

Lines changed: 169 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,10 @@
2222
from pyspark.mllib.common import inherit_doc
2323

2424

25-
__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor',
26-
'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel',
25+
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
26+
'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
27+
'GBTRegressor', 'GBTRegressionModel',
28+
'LinearRegression', 'LinearRegressionModel',
2729
'RandomForestRegressor', 'RandomForestRegressionModel']
2830

2931

@@ -609,6 +611,171 @@ class GBTRegressionModel(TreeEnsembleModels):
609611
"""
610612

611613

614+
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                            HasFitIntercept, HasMaxIter, HasTol):
    """
    Accelerated Failure Time (AFT) Model Survival Regression

    Fit a parametric AFT survival regression model based on the Weibull distribution
    of the survival time.

    .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_

    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0), 1.0),
    ...     (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
    >>> aftsr = AFTSurvivalRegression()
    >>> model = aftsr.fit(df)
    >>> model.predict(Vectors.dense(6.3))
    1.0
    >>> model.predictQuantiles(Vectors.dense(6.3))
    DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
    >>> model.transform(df).show()
    +-----+---------+------+----------+
    |label| features|censor|prediction|
    +-----+---------+------+----------+
    |  1.0|    [1.0]|   1.0|       1.0|
    |  0.0|(1,[],[])|   0.0|       1.0|
    +-----+---------+------+----------+
    ...

    .. versionadded:: 1.6.0
    """

    # a placeholder to make it appear in the generated doc
    censorCol = Param(Params._dummy(), "censorCol",
                      "censor column name. The value of this column could be 0 or 1. " +
                      "If the value is 1, it means the event has occurred i.e. " +
                      "uncensored; otherwise censored.")
    quantileProbabilities = \
        Param(Params._dummy(), "quantileProbabilities",
              "quantile probabilities array. Values of the quantile probabilities array " +
              "should be in the range (0, 1) and the array should be non-empty.")
    quantilesCol = Param(Params._dummy(), "quantilesCol",
                         "quantiles column name. This column will output quantiles of " +
                         "corresponding quantileProbabilities if it is set.")

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                 quantileProbabilities=None, quantilesCol=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                 quantilesCol=None):
        """
        super(AFTSurvivalRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
        # The class-level Params above are doc placeholders; the instances below are
        # the ones owned by this estimator.
        #: Param for censor column name
        self.censorCol = Param(self, "censorCol",
                               "censor column name. The value of this column could be 0 or 1. " +
                               "If the value is 1, it means the event has occurred i.e. " +
                               "uncensored; otherwise censored.")
        #: Param for quantile probabilities array
        self.quantileProbabilities = \
            Param(self, "quantileProbabilities",
                  "quantile probabilities array. Values of the quantile probabilities array " +
                  "should be in the range (0, 1) and the array should be non-empty.")
        #: Param for quantiles column name
        self.quantilesCol = Param(self, "quantilesCol",
                                  "quantiles column name. This column will output quantiles of " +
                                  "corresponding quantileProbabilities if it is set.")
        # Register the defaults once here; setParams relies on them being present
        # whenever the caller does not supply a value.
        self._setDefault(censorCol="censor",
                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.6.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                  quantileProbabilities=None, quantilesCol=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                  quantilesCol=None):

        Sets the params that were passed explicitly and returns this estimator.
        Params that are not passed keep their current value. Passing
        ``quantileProbabilities=None`` explicitly resets it to the default list.
        """
        # Work on a copy so the dict recorded by the decorator is not mutated.
        kwargs = dict(self.setParams._input_kwargs)
        # Only substitute the default when the caller explicitly passed None.
        # Testing the local argument instead would also fire when the keyword was
        # simply omitted, silently overwriting a previously configured value
        # (e.g. setQuantileProbabilities([0.5]) followed by setParams(maxIter=10)).
        if "quantileProbabilities" in kwargs and kwargs["quantileProbabilities"] is None:
            kwargs["quantileProbabilities"] = [0.01, 0.05, 0.1, 0.25, 0.5,
                                               0.75, 0.9, 0.95, 0.99]
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model in an :py:class:`AFTSurvivalRegressionModel`.
        """
        return AFTSurvivalRegressionModel(java_model)

    @since("1.6.0")
    def setCensorCol(self, value):
        """
        Sets the value of :py:attr:`censorCol`.
        """
        self._paramMap[self.censorCol] = value
        return self

    @since("1.6.0")
    def getCensorCol(self):
        """
        Gets the value of censorCol or its default value.
        """
        return self.getOrDefault(self.censorCol)

    @since("1.6.0")
    def setQuantileProbabilities(self, value):
        """
        Sets the value of :py:attr:`quantileProbabilities`.
        """
        self._paramMap[self.quantileProbabilities] = value
        return self

    @since("1.6.0")
    def getQuantileProbabilities(self):
        """
        Gets the value of quantileProbabilities or its default value.
        """
        return self.getOrDefault(self.quantileProbabilities)

    @since("1.6.0")
    def setQuantilesCol(self, value):
        """
        Sets the value of :py:attr:`quantilesCol`.
        """
        self._paramMap[self.quantilesCol] = value
        return self

    @since("1.6.0")
    def getQuantilesCol(self):
        """
        Gets the value of quantilesCol or its default value.
        """
        return self.getOrDefault(self.quantilesCol)
757+
758+
759+
class AFTSurvivalRegressionModel(JavaModel):
    """
    Model fitted by AFTSurvivalRegression.

    Both prediction methods forward their argument to ``_call_java`` under the
    method name of the same spelling and return whatever that call yields.

    .. versionadded:: 1.6.0
    """

    def predictQuantiles(self, features):
        """
        Predicted Quantiles

        :param features: value handed unchanged to the ``predictQuantiles`` call.
        :return: the result of ``_call_java("predictQuantiles", features)``.
        """
        quantiles = self._call_java("predictQuantiles", features)
        return quantiles

    def predict(self, features):
        """
        Predicted value

        :param features: value handed unchanged to the ``predict`` call.
        :return: the result of ``_call_java("predict", features)``.
        """
        prediction = self._call_java("predict", features)
        return prediction
777+
778+
612779
if __name__ == "__main__":
613780
import doctest
614781
from pyspark.context import SparkContext

0 commit comments

Comments
 (0)