|
22 | 22 | from pyspark.mllib.common import inherit_doc |
23 | 23 |
|
24 | 24 |
|
25 | | -__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', |
26 | | - 'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel', |
| 25 | +__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', |
| 26 | + 'DecisionTreeRegressor', 'DecisionTreeRegressionModel', |
| 27 | + 'GBTRegressor', 'GBTRegressionModel', |
| 28 | + 'LinearRegression', 'LinearRegressionModel', |
27 | 29 | 'RandomForestRegressor', 'RandomForestRegressionModel'] |
28 | 30 |
|
29 | 31 |
|
@@ -609,6 +611,171 @@ class GBTRegressionModel(TreeEnsembleModels): |
609 | 611 | """ |
610 | 612 |
|
611 | 613 |
|
| 614 | +@inherit_doc |
| 615 | +class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, |
| 616 | + HasFitIntercept, HasMaxIter, HasTol): |
| 617 | + """ |
| 618 | + Accelerated Failure Time (AFT) Model Survival Regression |
| 619 | +
|
| 620 | + Fit a parametric AFT survival regression model based on the Weibull distribution |
| 621 | + of the survival time. |
| 622 | +
|
| 623 | + .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_ |
| 624 | +
|
| 625 | + >>> from pyspark.mllib.linalg import Vectors |
| 626 | + >>> df = sqlContext.createDataFrame([ |
| 627 | + ... (1.0, Vectors.dense(1.0), 1.0), |
| 628 | + ... (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"]) |
| 629 | + >>> aftsr = AFTSurvivalRegression() |
| 630 | + >>> model = aftsr.fit(df) |
| 631 | + >>> model.predict(Vectors.dense(6.3)) |
| 632 | + 1.0 |
| 633 | + >>> model.predictQuantiles(Vectors.dense(6.3)) |
| 634 | + DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052]) |
| 635 | + >>> model.transform(df).show() |
| 636 | + +-----+---------+------+----------+ |
| 637 | + |label| features|censor|prediction| |
| 638 | + +-----+---------+------+----------+ |
| 639 | + | 1.0| [1.0]| 1.0| 1.0| |
| 640 | + | 0.0|(1,[],[])| 0.0| 1.0| |
| 641 | + +-----+---------+------+----------+ |
| 642 | + ... |
| 643 | +
|
| 644 | + .. versionadded:: 1.6.0 |
| 645 | + """ |
| 646 | + |
| 647 | + # a placeholder to make it appear in the generated doc |
| 648 | + censorCol = Param(Params._dummy(), "censorCol", |
| 649 | + "censor column name. The value of this column could be 0 or 1. " + |
| 650 | + "If the value is 1, it means the event has occurred i.e. " + |
| 651 | + "uncensored; otherwise censored.") |
| 652 | + quantileProbabilities = \ |
| 653 | + Param(Params._dummy(), "quantileProbabilities", |
| 654 | + "quantile probabilities array. Values of the quantile probabilities array " + |
| 655 | + "should be in the range (0, 1) and the array should be non-empty.") |
| 656 | + quantilesCol = Param(Params._dummy(), "quantilesCol", |
| 657 | + "quantiles column name. This column will output quantiles of " + |
| 658 | + "corresponding quantileProbabilities if it is set.") |
| 659 | + |
| 660 | + @keyword_only |
| 661 | + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", |
| 662 | + fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", |
| 663 | + quantileProbabilities=None, quantilesCol=None): |
| 664 | + """ |
| 665 | + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ |
| 666 | + fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ |
| 667 | + quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ |
| 668 | + quantilesCol=None): |
| 669 | + """ |
| 670 | + super(AFTSurvivalRegression, self).__init__() |
| 671 | + self._java_obj = self._new_java_obj( |
| 672 | + "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid) |
| 673 | + #: Param for censor column name |
| 674 | + self.censorCol = Param(self, "censorCol", |
| 675 | + "censor column name. The value of this column could be 0 or 1. " + |
| 676 | + "If the value is 1, it means the event has occurred i.e. " + |
| 677 | + "uncensored; otherwise censored.") |
| 678 | + #: Param for quantile probabilities array |
| 679 | + self.quantileProbabilities = \ |
| 680 | + Param(self, "quantileProbabilities", |
| 681 | + "quantile probabilities array. Values of the quantile probabilities array " + |
| 682 | + "should be in the range (0, 1) and the array should be non-empty.") |
| 683 | + #: Param for quantiles column name |
| 684 | + self.quantilesCol = Param(self, "quantilesCol", |
| 685 | + "quantiles column name. This column will output quantiles of " + |
| 686 | + "corresponding quantileProbabilities if it is set.") |
| 687 | + self._setDefault(censorCol="censor", |
| 688 | + quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]) |
| 689 | + kwargs = self.__init__._input_kwargs |
| 690 | + self.setParams(**kwargs) |
| 691 | + |
| 692 | + @keyword_only |
| 693 | + @since("1.6.0") |
| 694 | + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", |
| 695 | + fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", |
| 696 | + quantileProbabilities=None, quantilesCol=None): |
| 697 | + """ |
| 698 | + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ |
| 699 | + fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ |
| 700 | + quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ |
| 701 | + quantilesCol=None): |
| 702 | + """ |
| 703 | + kwargs = self.setParams._input_kwargs |
| 704 | + if quantileProbabilities is None: |
| 705 | + return self._set(**kwargs).setQuantileProbabilities([0.01, 0.05, 0.1, 0.25, 0.5, |
| 706 | + 0.75, 0.9, 0.95, 0.99]) |
| 707 | + else: |
| 708 | + return self._set(**kwargs) |
| 709 | + |
| 710 | + def _create_model(self, java_model): |
| 711 | + return AFTSurvivalRegressionModel(java_model) |
| 712 | + |
| 713 | + @since("1.6.0") |
| 714 | + def setCensorCol(self, value): |
| 715 | + """ |
| 716 | + Sets the value of :py:attr:`censorCol`. |
| 717 | + """ |
| 718 | + self._paramMap[self.censorCol] = value |
| 719 | + return self |
| 720 | + |
| 721 | + @since("1.6.0") |
| 722 | + def getCensorCol(self): |
| 723 | + """ |
| 724 | + Gets the value of censorCol or its default value. |
| 725 | + """ |
| 726 | + return self.getOrDefault(self.censorCol) |
| 727 | + |
| 728 | + @since("1.6.0") |
| 729 | + def setQuantileProbabilities(self, value): |
| 730 | + """ |
| 731 | + Sets the value of :py:attr:`quantileProbabilities`. |
| 732 | + """ |
| 733 | + self._paramMap[self.quantileProbabilities] = value |
| 734 | + return self |
| 735 | + |
| 736 | + @since("1.6.0") |
| 737 | + def getQuantileProbabilities(self): |
| 738 | + """ |
| 739 | + Gets the value of quantileProbabilities or its default value. |
| 740 | + """ |
| 741 | + return self.getOrDefault(self.quantileProbabilities) |
| 742 | + |
| 743 | + @since("1.6.0") |
| 744 | + def setQuantilesCol(self, value): |
| 745 | + """ |
| 746 | + Sets the value of :py:attr:`quantilesCol`. |
| 747 | + """ |
| 748 | + self._paramMap[self.quantilesCol] = value |
| 749 | + return self |
| 750 | + |
| 751 | + @since("1.6.0") |
| 752 | + def getQuantilesCol(self): |
| 753 | + """ |
| 754 | + Gets the value of quantilesCol or its default value. |
| 755 | + """ |
| 756 | + return self.getOrDefault(self.quantilesCol) |
| 757 | + |
| 758 | + |
| 759 | +class AFTSurvivalRegressionModel(JavaModel): |
| 760 | + """ |
| 761 | + Model fitted by AFTSurvivalRegression. |
| 762 | +
|
| 763 | + .. versionadded:: 1.6.0 |
| 764 | + """ |
| 765 | + |
| 766 | + def predictQuantiles(self, features): |
| 767 | + """ |
| 768 | + Predicted Quantiles |
| 769 | + """ |
| 770 | + return self._call_java("predictQuantiles", features) |
| 771 | + |
| 772 | + def predict(self, features): |
| 773 | + """ |
| 774 | + Predicted value |
| 775 | + """ |
| 776 | + return self._call_java("predict", features) |
| 777 | + |
| 778 | + |
612 | 779 | if __name__ == "__main__": |
613 | 780 | import doctest |
614 | 781 | from pyspark.context import SparkContext |
|
0 commit comments