Add mean IOU op.#10519
Conversation
|
请先不要review,我再优化下GPU kernel. |
1. Merge computing in GPU to two kernel. 2. Use wrong array and correct array instead of confusion matrix.
| "A Tensor representing the" | ||
| " mean intersection-over-union."); | ||
| AddOutput("out_wrong", "A Tensor with shape [num_classes]. "); | ||
| AddOutput("out_correct", "A Tensor with shape [num_classes]. "); |
There was a problem hiding this comment.
Please follow https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/dev/name_convention.md#opprotomaker-names for all the inputs and outputs name.
| AddComment(R"DOC( | ||
| mean-IOU Operator. | ||
| Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, which first computes the IOU for each semantic class and then computes the average over classes. IOU is defined as follows: IOU = true_positive / (true_positive + false_positive + false_negative). The predictions are accumulated in a confusion matrix and mean-IOU is then calculated from it. | ||
|
|
There was a problem hiding this comment.
Since we have iou_similarity_op:
The doc here better to give more details for the difference.
| .AsDispensable(); | ||
| AddOutput("out_mean_iou", | ||
| "A Tensor representing the" | ||
| " mean intersection-over-union."); |
There was a problem hiding this comment.
Also need to give the shape.
| REGISTER_OPERATOR(mean_iou, ops::MeanIoUOp, ops::MeanIoUOpMaker, | ||
| paddle::framework::EmptyGradOpMaker); | ||
| REGISTER_OP_CPU_KERNEL(mean_iou, ops::MeanIoUKernel<int>, | ||
| ops::MeanIoUKernel<int64_t>); |
There was a problem hiding this comment.
上面文档描述里是支持int32和int64,这里没有注册int32。
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OP_CUDA_KERNEL(mean_iou, ops::MeanIoUCUDAOpKernel<int>, | ||
| ops::MeanIoUKernel<int64_t>); |
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput("predictions", | ||
| "A Tensor of prediction results for semantic labels" | ||
| " with type int32 or int64."); |
| float* out_mean_iou_data = | ||
| out_mean_iou->mutable_data<float>(ctx.GetPlace()); | ||
|
|
||
| // get eigen tensor |
| auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong); | ||
| auto out_correct_t = EigenTensor<int, 1>::From(*out_correct); | ||
|
|
||
| // Tmp tensor |
| .AsDispensable(); | ||
| AddInput("in_mean_iou", | ||
| "A list of Tensor that Output(mean_iou) should " | ||
| "be added to. Empty list is also valid here.") |
There was a problem hiding this comment.
in_wrongs, in_corrects, in_mean_iou是干啥的?和out_wrong/correct/mean_iou有啥区别?
There was a problem hiding this comment.
in_wrongs, in_corrects, in_mean_iou之前当前batch之前累计的数据,加上当前batch的统计结果,就得到:out_wrong/correct/mean_iou
| for (int i = threadIdx.x; i < num_classes; i += blockDim.x) { | ||
| atomicAdd(wrong + i, wrong_c[i]); | ||
| atomicAdd(correct + i, correct_c[i]); | ||
| } |
There was a problem hiding this comment.
如果num_classes较小, predictions的shape较大,会导致这个kernel的性能非常低效,其实感觉类似这样的kernel,先CPU即可,后续最好评估下时间。
There was a problem hiding this comment.
| input size | class_num | GPU | CPU |
|---|---|---|---|
| 1024 * 2048 | 100 | 0.168812ms | 13.0831ms |
| 1024 * 2048 | 50 | 0.172748ms | 13.5145ms |
| 1024 * 2048 | 20 | 0.174807ms | 14.4619ms |
| 1024 * 2048 | 10 | 0.188483ms | 16.1516ms |
| 1024 * 2048 | 1 | 0.230743ms | 12.7893ms |
| 1024 * 2048 *2 | 100 | 0.308306ms | 26.4576 |
| 1024 * 2048 * 2 | 50 | 0.326073ms | 26.9835ms |
| 1024 * 2048 *2 | 20 | 0.28971ms | 29.0224ms |
| 1024 * 2048* 2 | 10 | 0.267694ms | 34.2029ms |
| 1024 * 2048 * 2 | 1 | 0.295844ms | 25.4808ms |
| 'softmax_with_cross_entropy', 'smooth_l1', 'one_hot', | ||
| 'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 'pad', | ||
| 'label_smooth', 'roi_pool', 'dice_loss', 'image_resize', | ||
| 'image_resize_short', 'resize_bilinear', 'gather', 'random_crop', 'mean_iou' |
There was a problem hiding this comment.
This is due to he yapf version?
There was a problem hiding this comment.
The version of my yapf is 0.22. I'm not sure it is due to he yapf version. Which style is correct?
Performance on GPU P40:
test code: