Refine accuracy_op CUDA kernel #4097
Conversation
paddle/operators/accuracy_op.cu
Outdated
    }
  }
  *accuracy = static_cast<float>(correct) / static_cast<float>(N);
  atomicAdd(&total, count);
LOL. It seems too complicated to me. It looks correct, but I really don't have enough experience to review it. Please ask @hedaoyuan to review it, or rewrite it based on some high-level library such as thrust.
If you are interested in thrust, I will give a demo later.
  *accuracy = static_cast<float>(correct) / static_cast<float>(N);
  atomicAdd(&total, count);
  __syncthreads();
  if (threadIdx.x == 0) {
It seems that when num_samples is greater than 4096 there will be multiple blocks, and the result may be incorrect. I think you can consider using only one block for the calculation.
Try to avoid using atomicAdd; the __shared__ int total; can be replaced by __shared__ int total[block_size]; with a reduce added at the end.
Thanks very much for the advice, I'll try to fix it.
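The reduction the reviewer suggests can be sketched on the host. This is only an illustrative simulation (kBlockSize and BlockAccuracy are made-up names, not Paddle's API): each "thread" slot accumulates a partial correct-count, then the partials are combined by halving strides, as a CUDA block reduction over __shared__ int total[BlockSize] would do.

```cpp
#include <vector>

// kBlockSize plays the role of PADDLE_CUDA_NUM_THREADS; a power of two
// so the halving tree reduction below terminates at slot 0.
constexpr int kBlockSize = 8;

float BlockAccuracy(const std::vector<int>& pred,
                    const std::vector<int>& label) {
  int total[kBlockSize] = {0};  // stands in for the __shared__ array
  const int n = static_cast<int>(pred.size());
  // Single-block stride loop: "thread" t handles indices t, t+BlockSize, ...
  for (int t = 0; t < kBlockSize; ++t) {
    for (int i = t; i < n; i += kBlockSize) {
      if (pred[i] == label[i]) ++total[t];
    }
  }
  // Tree reduction; on the GPU each step would be followed by __syncthreads().
  for (int stride = kBlockSize / 2; stride > 0; stride /= 2) {
    for (int t = 0; t < stride; ++t) total[t] += total[t + stride];
  }
  return static_cast<float>(total[0]) / static_cast<float>(n);
}
```

This avoids atomicAdd entirely: no two slots ever write the same element in a given stride step.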
    const int pred = Xdata[row * D + col];
    if (pred == label) {
      ++correct;
__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
Using template <int BlockSize> is better than using the PADDLE_CUDA_NUM_THREADS macro.
I would prefer PADDLE_CUDA_NUM_THREADS; it's a constexpr.
If we use template <int BlockSize>, we have to pass BlockSize twice when calling the kernel, like SomeKernel<BlockSize><<<1, BlockSize>>>(), whereas PADDLE_CUDA_NUM_THREADS can always be the same const value.
Well, I think we can open an issue to discuss this and then write some CUDA development docs.
Using the PADDLE_CUDA_NUM_THREADS macro also amounts to writing it twice.
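The point being debated can be shown in plain C++ (names here are hypothetical, just for illustration): both a template parameter and a constexpr are compile-time constants, usable wherever the compiler needs a constant expression, e.g. a std::array size.

```cpp
#include <array>

// Analogue of PADDLE_CUDA_NUM_THREADS as a constexpr.
constexpr int kNumThreads = 512;

// BlockSize is a compile-time value, usable as an array extent.
template <int BlockSize>
int TemplateSize() {
  std::array<int, BlockSize> buf{};
  return static_cast<int>(buf.size());
}

// The constexpr works the same way, without a template parameter.
int ConstexprSize() {
  std::array<int, kNumThreads> buf{};
  return static_cast<int>(buf.size());
}
```

So the trade-off is purely ergonomic: the template form repeats BlockSize at each call site, the constexpr form names the constant once.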
paddle/operators/accuracy_op.cu
Outdated
  __shared__ int total[PADDLE_CUDA_NUM_THREADS];

  // support only 1 block
  for (int i = threadIdx.x; i < (N); i += blockDim.x * gridDim.x) {
i += blockDim.x * gridDim.x -> i += BlockSize
Using BlockSize is better than blockDim.x (BlockSize is a compile-time value).
I had tried to write a demo to show how thrust could be used in that case. I found that it's also quite painful to write:

#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/reduce.h>
#include <thrust/tuple.h>

// Return 1 when the zipped (prediction, label) pair matches, else 0.
struct IsEqual {
  __host__ __device__ int operator()(const thrust::tuple<int, int>& t) const {
    return thrust::get<0>(t) == thrust::get<1>(t) ? 1 : 0;
  }
};

void accuracy(const int* data, const int* label, int length, float* acc) {
  *acc = 0.f;
  thrust::device_vector<int> data_device(data, data + length);
  thrust::device_vector<int> label_device(label, label + length);
  auto first = thrust::make_zip_iterator(
      thrust::make_tuple(data_device.begin(), label_device.begin()));
  auto last = thrust::make_zip_iterator(
      thrust::make_tuple(data_device.end(), label_device.end()));
  thrust::device_vector<int> correct(length);
  thrust::transform(first, last, correct.begin(), IsEqual());
  int result = thrust::reduce(correct.begin(), correct.end());
  if (length != 0) *acc = static_cast<float>(result) / length;
}

Related document: zip_iterator
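For comparison, the same count-matches-then-divide computation can be written on the CPU with the standard library alone (AccuracyCpu is a made-up name for this sketch); std::inner_product with a custom "multiply" plays the role of the zip + transform + reduce pipeline.

```cpp
#include <numeric>
#include <vector>

// Count positions where prediction equals label, divide by length.
float AccuracyCpu(const std::vector<int>& pred,
                  const std::vector<int>& label) {
  int correct = std::inner_product(
      pred.begin(), pred.end(), label.begin(), 0,
      std::plus<int>(),                                // "reduce" step
      [](int p, int l) { return p == l ? 1 : 0; });    // "transform" step
  return pred.empty() ? 0.f
                      : static_cast<float>(correct) / pred.size();
}
```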
@dzhwinter Thanks for the very useful code! Both ways are OK, either using thrust or not; I also prefer using
Fix #4096