-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Refine accuracy_op CUDA kernel #4097
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,21 +17,25 @@ limitations under the License. */ | |
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| __global__ void AccuracySingleKernel(const int N, const int D, const int top_k, | ||
| const int* Xdata, const int* labelData, | ||
| float* accuracy) { | ||
| int correct = 0; | ||
| for (int row = 0; row < N; row++) { | ||
| const int label = labelData[row]; | ||
| for (int col = 0; col < D; col++) { | ||
| const int pred = Xdata[row * D + col]; | ||
| if (pred == label) { | ||
| ++correct; | ||
| __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata, | ||
| const int* labeldata, float* accuracy) { | ||
| int count = 0; | ||
| __shared__ int total; | ||
| total = 0; | ||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (N); | ||
| i += blockDim.x * gridDim.x) { | ||
| for (int j = 0; j < D; ++j) { | ||
| if (Xdata[i * D + j] == labeldata[i]) { | ||
| ++count; | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| *accuracy = static_cast<float>(correct) / static_cast<float>(N); | ||
| atomicAdd(&total, count); | ||
|
||
| __syncthreads(); | ||
| if (threadIdx.x == 0) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that when num_samples is greater than 4096 there will be multiple blocks, and this result may be incorrect. I think can consider using only one block to calculate.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks very much for the advise, I'll try to fix it.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| *accuracy = static_cast<float>(total) / static_cast<float>(N); | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
|
|
@@ -57,8 +61,10 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { | |
| return; | ||
| } | ||
|
|
||
| AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data, | ||
| label_data, accuracy_data); | ||
| int threads = 512; | ||
| int grids = (num_samples + 4096 - 1) / 4096; | ||
| AccuracyCudaKernel<<<grids, threads>>>( | ||
| num_samples, infer_width, inference_data, label_data, accuracy_data); | ||
| } | ||
| }; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use
template <int BlockSize>is better than usePADDLE_CUDA_NUM_THREADSmacro.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer
PADDLE_CUDA_NUM_THREADS, it's a constexpr.If use
template <int BlockSize>we have to passBlockSizetwice to when calling the kernel like:SomeKernel<BlockSize><<<1, BlockSize>>>(), and seemsPADDLE_CUDA_NUM_THREADScan be always the same const value.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, I think we can open an issue to discuss and then write some CUDA development docs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using the
PADDLE_CUDA_NUM_THREADSmacro is also equivalent to using it twice.