Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions paddle/operators/accuracy_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,25 @@ limitations under the License. */
namespace paddle {
namespace operators {

__global__ void AccuracySingleKernel(const int N, const int D, const int top_k,
const int* Xdata, const int* labelData,
float* accuracy) {
int correct = 0;
for (int row = 0; row < N; row++) {
const int label = labelData[row];
for (int col = 0; col < D; col++) {
const int pred = Xdata[row * D + col];
if (pred == label) {
++correct;
__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use template <int BlockSize> is better than use PADDLE_CUDA_NUM_THREADS macro.

Copy link
Contributor Author

@typhoonzero typhoonzero Sep 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer PADDLE_CUDA_NUM_THREADS, it's a constexpr.

If use template <int BlockSize> we have to pass BlockSize twice to when calling the kernel like: SomeKernel<BlockSize><<<1, BlockSize>>>(), and seems PADDLE_CUDA_NUM_THREADS can be always the same const value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I think we can open an issue to discuss and then write some CUDA development docs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the PADDLE_CUDA_NUM_THREADS macro is also equivalent to using it twice.

const int* labeldata, float* accuracy) {
int count = 0;
__shared__ int total;
total = 0;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (N);
i += blockDim.x * gridDim.x) {
for (int j = 0; j < D; ++j) {
if (Xdata[i * D + j] == labeldata[i]) {
++count;
break;
}
}
}
*accuracy = static_cast<float>(correct) / static_cast<float>(N);
atomicAdd(&total, count);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LOL. It seems too complicated to me. It looks correct but I really don't have enough experience to review. Please refer to @hedaoyuan review it, or write it base on some high-level library thrust.

__syncthreads();
if (threadIdx.x == 0) {
Copy link
Contributor

@hedaoyuan hedaoyuan Sep 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that when num_samples is greater than 4096 there will be multiple blocks, and this result may be incorrect. I think can consider using only one block to calculate.
Try to avoid using atomicAdd, the __shared__ int total; can be replaced by __shared__ int total[block_size]; and add reduce at the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks very much for the advise, I'll try to fix it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

*accuracy = static_cast<float>(total) / static_cast<float>(N);
}
}

template <typename T>
Expand All @@ -57,8 +61,10 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return;
}

AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data,
label_data, accuracy_data);
int threads = 512;
int grids = (num_samples + 4096 - 1) / 4096;
AccuracyCudaKernel<<<grids, threads>>>(
num_samples, infer_width, inference_data, label_data, accuracy_data);
}
};

Expand Down