Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions paddle/operators/accuracy_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include "paddle/operators/accuracy_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;

__global__ void AccuracySingleKernel(const int N, const int D, const int top_k,
const int* Xdata, const int* labelData,
float* accuracy) {
int correct = 0;
for (int row = 0; row < N; row++) {
const int label = labelData[row];
for (int col = 0; col < D; col++) {
const int pred = Xdata[row * D + col];
if (pred == label) {
++correct;
__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use template <int BlockSize> is better than use PADDLE_CUDA_NUM_THREADS macro.

Copy link
Contributor Author

@typhoonzero typhoonzero Sep 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer PADDLE_CUDA_NUM_THREADS, it's a constexpr.

If use template <int BlockSize> we have to pass BlockSize twice to when calling the kernel like: SomeKernel<BlockSize><<<1, BlockSize>>>(), and seems PADDLE_CUDA_NUM_THREADS can be always the same const value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I think we can open an issue to discuss and then write some CUDA development docs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the PADDLE_CUDA_NUM_THREADS macro is also equivalent to using it twice.

const int* labeldata, float* accuracy) {
int count = 0;
__shared__ int total[PADDLE_CUDA_NUM_THREADS];

// support only 1 block
for (int i = threadIdx.x; i < (N); i += blockDim.x * gridDim.x) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i += blockDim.x * gridDim.x -> i += BlockSize
Use BlockSize better than blockDim.x (BlockSize is the value at compile time).

for (int j = 0; j < D; ++j) {
if (Xdata[i * D + j] == labeldata[i]) {
++count;
break;
}
}
}
*accuracy = static_cast<float>(correct) / static_cast<float>(N);
total[threadIdx.x] = count;
__syncthreads();

// reduce the count with init value 0, and output accuracy.
int result =
thrust::reduce(thrust::device, total, total + PADDLE_CUDA_NUM_THREADS, 0);
if (threadIdx.x == 0) {
Copy link
Contributor

@hedaoyuan hedaoyuan Sep 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that when num_samples is greater than 4096 there will be multiple blocks, and this result may be incorrect. I think can consider using only one block to calculate.
Try to avoid using atomicAdd, the __shared__ int total; can be replaced by __shared__ int total[block_size]; and add reduce at the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks very much for the advise, I'll try to fix it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

*accuracy = static_cast<float>(result) / static_cast<float>(N);
}
}

template <typename T>
Expand All @@ -57,8 +69,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return;
}

AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data,
label_data, accuracy_data);
AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(
num_samples, infer_width, inference_data, label_data, accuracy_data);
}
};

Expand Down
5 changes: 5 additions & 0 deletions paddle/platform/cuda_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ namespace platform {
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }

// Default thread count per block(or block size).
// TODO(typhoonzero): need to benchmark against setting this value
// to 1024.
constexpr int PADDLE_CUDA_NUM_THREADS = 512;

// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);

Expand Down
9 changes: 5 additions & 4 deletions python/paddle/v2/framework/tests/test_accuracy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
class TestAccuracyOp(OpTest):
def setUp(self):
self.op_type = "accuracy"
infer = np.random.randint(0, 2, (32, 1)).astype("int")
label = np.random.randint(0, 2, (32, )).astype("int")
n = 8192
infer = np.random.randint(0, 2, (n, 1)).astype("int")
label = np.random.randint(0, 2, (n, )).astype("int")
self.inputs = {'Inference': infer, "Label": label}
num_correct = 0
for rowid in xrange(32):
for rowid in xrange(n):
for ele in infer[rowid]:
if ele == label[rowid]:
num_correct += 1
break
self.outputs = {'Accuracy': [num_correct / 32.0]}
self.outputs = {'Accuracy': [num_correct / float(n)]}

def test_check_output(self):
self.check_output()
Expand Down