Auc op #4063
Changes from 8 commits
paddle/operators/auc_op.cc (new file, +83 lines):

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/auc_op.h"

namespace paddle {
namespace operators {

class AucOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"),
                            "Input of Inference must be initialized.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                            "Input of Label must be initialized.");
    auto *inference = ctx.Input<framework::Tensor>("Inference");
    auto *label = ctx.Input<framework::Tensor>("Label");

    PADDLE_ENFORCE_EQ(inference->dims(), label->dims(),
                      "inference and label should have same shape");

    ctx.Output<framework::LoDTensor>("AUC")->Resize({1});
  }
};

class AucOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Inference",
             "A floating point `Tensor` of arbitrary shape whose values "
             "are in the range `[0, 1]`.");
    AddInput("Label",
             "A `Tensor` whose shape matches "
             "`Inference`. Will be cast to `bool`.");
    // TODO(typhoonzero): support weight input
    AddOutput("AUC",
              "A scalar `Tensor` representing the "
              "current area-under-curve.");
```
Review comment (Collaborator): I think it's better to tell users "what is it" before "how to represent it".

Reply (Author): Correct, but in this line the "what is it" part is covered: the output is indeed a scalar, and "representing ..." says what it does.
```cpp
    AddAttr<std::string>("curve", "Possible curves are ROC and PR")
        .SetDefault("ROC");
    AddAttr<int>("num_thresholds",
                 "The number of thresholds to use when discretizing the"
                 " ROC curve.")
        .SetDefault(200);

    AddComment(
        R"DOC(Computes the AUC according to the forward output and label.
Best to use for binary classification evaluations.
If `label` can be values other than 0 and 1, it will be cast
to bool.

You can find the definitions here:
https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve

Possible curves are:
- ROC: Receiver operating characteristic
- PR: Precision Recall
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker);
REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel<paddle::platform::CPUPlace, float>);
```
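For intuition about the `num_thresholds` attribute: both the kernel and the unit test below discretize the score range into `num_thresholds` cutoffs, with interior points at i/(num_thresholds - 1) and a small epsilon of padding below 0 and above 1. A minimal numpy sketch of that grid (not part of the PR; the helper name `make_thresholds` is hypothetical):

```python
import numpy as np

def make_thresholds(num_thresholds=200, eps=1e-7):
    """Interior cutoffs at i * 1.0 / (num_thresholds - 1), padded by
    -eps and 1 + eps, mirroring what the AUC kernel builds."""
    inner = [(i + 1) * 1.0 / (num_thresholds - 1)
             for i in range(num_thresholds - 2)]
    return np.array([0.0 - eps] + inner + [1.0 + eps])

thresholds = make_thresholds()
print(len(thresholds))   # 200
print(thresholds[:3])    # roughly: -1e-07, 0.005025, 0.010050
print(thresholds[-1])    # roughly: 1.0000001
```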
paddle/operators/auc_op.h (new file, +137 lines):

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class AucKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* inference = ctx.Input<Tensor>("Inference");
    auto* label = ctx.Input<Tensor>("Label");
    auto* auc = ctx.Output<Tensor>("AUC");

    float* auc_data = auc->mutable_data<float>(ctx.GetPlace());
```
Review comment (Contributor): float or T?

Reply (Author): Evaluator output is always float.
```cpp
    std::string curve = ctx.Attr<std::string>("curve");
    int num_thresholds = ctx.Attr<int>("num_thresholds");
    std::vector<float> thresholds_list;
    // allocate num_thresholds entries so they can be assigned by index below
    thresholds_list.resize(num_thresholds);
    for (int i = 1; i < num_thresholds - 1; i++) {
      thresholds_list[i] = (float)i / (num_thresholds - 1);
    }
    const float kEpsilon = 1e-7;
```
Review comment (Contributor): Does this overflow the accuracy range?

Reply (Author): Sorry, I didn't get your point?

Reply (Contributor): Oh, maybe the float type only has about 7 significant digits; I'm not sure about this.
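On the precision question above: a single-precision float carries roughly 7 significant decimal digits, and `kEpsilon = 1e-7` sits just below the float32 machine epsilon (~1.19e-7). Since the padding is only ever applied to 0.0 and 1.0, the boundary thresholds remain distinct values; a quick numpy check (an illustration, not part of the PR):

```python
import numpy as np

kEpsilon = np.float32(1e-7)

# float32 machine epsilon (gap between 1.0 and the next representable value)
print(np.finfo(np.float32).eps)                      # 1.1920929e-07

# the padded boundary thresholds are still distinct from 0.0 and 1.0
print(np.float32(0.0) - kEpsilon < np.float32(0.0))  # True
print(np.float32(1.0) + kEpsilon > np.float32(1.0))  # True
```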
```cpp
    thresholds_list[0] = 0.0f - kEpsilon;
    thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;

    size_t num_samples = inference->numel();

    const T* inference_data = inference->data<T>();
    Tensor label_casted;
    label_casted.Resize(label->dims());
    bool* label_casted_data = label_casted.mutable_data<bool>(ctx.GetPlace());

    const int* label_data = label->data<int>();
    // cast label_data to bool
    for (size_t i = 0; i < num_samples; i++) {
      label_casted_data[i] = static_cast<bool>(label_data[i]);
    }

    // Create local tensors for storing the curve: TP, FN, TN, FP
    // TODO(typhoonzero): put these tensors in Scope
    // TODO(typhoonzero): use an op to calculate these values.
    Tensor true_positive, false_positive, true_negative, false_negative;

    true_positive.Resize({num_thresholds});
    false_negative.Resize({num_thresholds});
    true_negative.Resize({num_thresholds});
    false_positive.Resize({num_thresholds});

    int* tp_data = true_positive.mutable_data<int>(ctx.GetPlace());
    int* fn_data = false_negative.mutable_data<int>(ctx.GetPlace());
    int* tn_data = true_negative.mutable_data<int>(ctx.GetPlace());
    int* fp_data = false_positive.mutable_data<int>(ctx.GetPlace());

    for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
      // calculate TP, FN, TN, FP for the current threshold
      int tp = 0, fn = 0, tn = 0, fp = 0;
      for (size_t i = 0; i < num_samples; i++) {
        if (label_casted_data[i]) {
          if (inference_data[i] >= (thresholds_list[idx_thresh])) {
            tp++;
          } else {
            fn++;
          }
        } else {
          if (inference_data[i] >= (thresholds_list[idx_thresh])) {
            fp++;
          } else {
            tn++;
          }
        }
      }
      // store counts
      tp_data[idx_thresh] = tp;
      fn_data[idx_thresh] = fn;
      tn_data[idx_thresh] = tn;
      fp_data[idx_thresh] = fp;
    }
    // epsilon to avoid divide by zero.
    float epsilon = 1e-6;
    // Riemann sum to calculate auc.
    Tensor tp_rate, fp_rate, rec_rate;
    tp_rate.Resize({num_thresholds});
    fp_rate.Resize({num_thresholds});
    rec_rate.Resize({num_thresholds});
    float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace());
    float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace());
    float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace());
    for (int i = 0; i < num_thresholds; i++) {
```
Review comment (Contributor): Maybe we can convert the Tensor to an Eigen::Tensor and do vectorized computation instead of a loop.

Reply (Author): Added to the TODO; will refine it if Eigen has enough operators.

Reply (Contributor): Maybe https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/elementwise_div_op.h#L71
```cpp
      tp_rate_data[i] =
          ((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon);
      fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon);
      rec_rate_data[i] =
          ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon);
    }
    *auc_data = 0.0f;
    if (curve == "ROC") {
      for (int i = 0; i < num_thresholds - 1; i++) {
        auto dx = fp_rate_data[i] - fp_rate_data[i + 1];
        auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f;
        *auc_data = *auc_data + dx * y;
      }
    } else if (curve == "PR") {
      for (int i = 1; i < num_thresholds; i++) {
        auto dx = tp_rate_data[i] - tp_rate_data[i - 1];
        auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f;
        *auc_data = *auc_data + dx * y;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
```
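As a companion to the review suggestion about vectorizing the per-threshold loop, here is a numpy sketch of the same rate formulas and Riemann sums the kernel uses (an illustration only, not Eigen/Paddle code; the function name `auc_from_counts` is hypothetical):

```python
import numpy as np

def auc_from_counts(tp, fn, tn, fp, curve="ROC", epsilon=1e-6):
    """tp, fn, tn, fp: arrays of length num_thresholds holding the
    confusion counts per threshold, in the same layout as the kernel."""
    tp = tp.astype("float32")
    fn = fn.astype("float32")
    tn = tn.astype("float32")
    fp = fp.astype("float32")

    tpr = (tp + epsilon) / (tp + fn + epsilon)  # tp_rate_data
    fpr = fp / (fp + tn + epsilon)              # fp_rate_data
    rec = (tp + epsilon) / (tp + fp + epsilon)  # rec_rate_data: tp / (tp + fp)

    if curve == "ROC":
        dx = fpr[:-1] - fpr[1:]
        y = (tpr[:-1] + tpr[1:]) / 2.0
    else:  # "PR"
        dx = tpr[1:] - tpr[:-1]
        y = (rec[1:] + rec[:-1]) / 2.0
    return np.sum(dx * y)
```

Keeping the same epsilon conventions means the result should match the loop version up to float rounding.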
Python unit test for the auc op (new file, +66 lines):

```python
import unittest
import numpy as np
from op_test import OpTest


class TestAucOp(OpTest):
    def setUp(self):
        self.op_type = "auc"
        pred = np.random.random((128)).astype("float32")
        labels = np.random.randint(0, 2, (128, ))
        num_thresholds = 200
        self.inputs = {'Inference': pred, 'Label': labels}
        self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
        # NOTE: sklearn uses a different way to generate thresholds,
        # which makes the result differ slightly:
        #   from sklearn.metrics import roc_curve, auc
        #   fpr, tpr, thresholds = roc_curve(labels, pred)
        #   auc_value = auc(fpr, tpr)
        # so we calculate AUC again using numpy for testing.
        kepsilon = 1e-7  # to account for floating point imprecisions
        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                      for i in range(num_thresholds - 2)]
        thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

        # calculate TP, FN, TN, FP counts
        tp_list = np.ndarray((num_thresholds, ))
        fn_list = np.ndarray((num_thresholds, ))
        tn_list = np.ndarray((num_thresholds, ))
        fp_list = np.ndarray((num_thresholds, ))
        for idx_thresh, thresh in enumerate(thresholds):
            tp, fn, tn, fp = 0, 0, 0, 0
            for i, lbl in enumerate(labels):
                if lbl:
                    if pred[i] >= thresh:
                        tp += 1
                    else:
                        fn += 1
                else:
                    if pred[i] >= thresh:
                        fp += 1
                    else:
                        tn += 1
            tp_list[idx_thresh] = tp
            fn_list[idx_thresh] = fn
            tn_list[idx_thresh] = tn
            fp_list[idx_thresh] = fp

        epsilon = 1e-6
        tpr = (tp_list.astype("float32") + epsilon) / (
            tp_list + fn_list + epsilon)
        fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
        rec = (tp_list.astype("float32") + epsilon) / (
            tp_list + fp_list + epsilon)

        x = fpr[:num_thresholds - 1] - fpr[1:]
        y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
        auc_value = np.sum(x * y)

        self.outputs = {'AUC': auc_value}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
```
Review comment: Input or Label?

Reply: Done.
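The test's NOTE explains why sklearn is not used as the reference: sklearn's `roc_curve` derives thresholds from the distinct predicted scores rather than a fixed grid, so its AUC differs slightly from the op's discretized estimate. A small sketch of that cross-check, assuming sklearn is installed (it is not a dependency of this test):

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

pred = np.random.random((128, )).astype("float32")
labels = np.random.randint(0, 2, (128, ))

# thresholds come from the sorted unique scores; auc() integrates with the
# trapezoidal rule, so the value is close to, but not identical with,
# the fixed-grid estimate produced by the auc op.
fpr, tpr, thresholds = roc_curve(labels, pred)
auc_value = auc(fpr, tpr)
print("sklearn ROC-AUC:", auc_value)
```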