
Commit e29fb5c

Add BatchParallelFor, TryParallelFor, TryBatchParallelFor into ThreadPool (#2476)
1 parent d6c8492 commit e29fb5c

5 files changed: +177 -34 lines

include/onnxruntime/core/platform/threadpool.h

Lines changed: 42 additions & 0 deletions
@@ -46,6 +46,11 @@ class ThreadPool {
   */
   void ParallelFor(int32_t total, std::function<void(int32_t)> fn);
 
+  /*
+  Schedule work in the interval [0, total), with calls split into (num_batches) batches.
+  */
+  void BatchParallelFor(int32_t total, std::function<void(int32_t)> fn, int32_t num_batches = 0);
+
   /*
   Schedule work in the interval [first, last].
   */
@@ -54,6 +59,43 @@
   // This is not supported until the latest Eigen
   // void SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions);
 
+  /**
+  Tries to call the given function in parallel, with calls split into (num_batches) batches.
+  **/
+  template <typename F>
+  inline static void TryBatchParallelFor(concurrency::ThreadPool* tp, int32_t total, F&& fn, int32_t num_batches = 0) {
+    if (tp != nullptr) {
+      if (num_batches <= 0) {
+        num_batches = tp->NumThreads() + 1;
+      }
+      tp->BatchParallelFor(total, std::forward<F>(fn), num_batches);
+    } else {
+#ifdef USE_OPENMP
+#pragma omp parallel for
+#endif
+      for (int32_t i = 0; i < total; ++i) {
+        fn(i);
+      }
+    }
+  }
+
+  /**
+  Tries to call the given function in parallel.
+  **/
+  template <typename F>
+  inline static void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, F&& fn) {
+    if (tp != nullptr) {
+      tp->ParallelFor(total, std::forward<F>(fn));
+    } else {
+#ifdef USE_OPENMP
+#pragma omp parallel for
+#endif
+      for (int32_t i = 0; i < total; ++i) {
+        fn(i);
+      }
+    }
+  }
+
   int NumThreads() const;
 
   int CurrentThreadId() const;
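
As a usage illustration, here is a minimal hypothetical caller; ScaleRows, its parameters, and the scaling workload are invented for this sketch and are not part of the commit. It dispatches one call per row through TryBatchParallelFor; with a null pool the work runs serially, or under OpenMP if USE_OPENMP is defined.

#include "core/platform/threadpool.h"

namespace onnxruntime {

// Hypothetical kernel: scales each row of a rows x cols matrix by alpha.
void ScaleRows(concurrency::ThreadPool* tp, float* data,
               int32_t rows, int32_t cols, float alpha) {
  // num_batches is left at its default of 0, so a non-null pool
  // splits the rows into NumThreads() + 1 contiguous batches.
  concurrency::ThreadPool::TryBatchParallelFor(tp, rows, [=](int32_t r) {
    for (int32_t c = 0; c < cols; ++c) {
      data[r * cols + c] *= alpha;
    }
  });
}

}  // namespace onnxruntime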

onnxruntime/contrib_ops/cpu/crop_and_resize.cc

Lines changed: 2 additions & 17 deletions
@@ -43,17 +43,6 @@ namespace contrib {
 
 ADD_TYPED_CROPANDRESIZE_OP(float);
 
-template <typename T>
-static void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, T&& fn) {
-  if (tp != nullptr)
-    tp->ParallelFor(total, fn);
-  else {
-    for (int32_t i = 0; i != total; ++i) {
-      fn(i);
-    }
-  }
-}
-
 template <typename T>
 void CropAndResizeForward(const TensorShape& output_shape,
                           const T* bottom_data,
@@ -71,9 +60,7 @@ void CropAndResizeForward(const TensorShape& output_shape,
   int64_t pooled_height = output_shape[2];
   int64_t pooled_width = output_shape[3];
 
-  // TODO: This should do blocks of work based on the number of threads in the threadpool with each block
-  // being n_rois / num_threads
-  std::function<void(int32_t)> work_object = [&](int32_t n) {
+  ThreadPool::TryBatchParallelFor(ttp, static_cast<int32_t>(n_rois), [&](int32_t n) {
     int64_t index_n = n * channels * pooled_width * pooled_height;
 
     const T* offset_bottom_rois = bottom_rois + n * num_roi_cols;
@@ -182,9 +169,7 @@ void CropAndResizeForward(const TensorShape& output_shape,
       }
     }  // for pw
   }  // for ph
-  };  // for n
-
-  TryParallelFor(ttp, static_cast<int32_t>(n_rois), work_object);
+  });  // for n
 }
 
 template <typename T>

onnxruntime/core/common/threadpool.cc

Lines changed: 30 additions & 0 deletions
@@ -57,6 +57,36 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
   barrier.Wait();
 }
 
+void ThreadPool::BatchParallelFor(int32_t total, std::function<void(int32_t)> fn, int32_t num_batches) {
+  if (total <= 0)
+    return;
+
+  if (total == 1) {
+    fn(0);
+    return;
+  }
+
+  if (num_batches <= 1) {
+    for (int i = 0; i < total; i++) {
+      fn(i);
+    }
+    return;
+  }
+
+  if (num_batches >= total) {
+    ParallelFor(total, fn);
+    return;
+  }
+
+  ParallelFor(num_batches, [&](int batch_index) {
+    int start = batch_index * total / num_batches;
+    int end = (batch_index + 1) * total / num_batches;
+    for (int i = start; i < end; i++) {
+      fn(i);
+    }
+  });
+}
+
 void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
   if (last <= first) return;
   if (last - first == 1) {
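
The start/end arithmetic partitions [0, total) into contiguous, near-equal ranges with no explicit remainder handling: with total = 81 and num_batches = 20, batch 0 covers [0, 4), batch 1 covers [4, 8), and so on, each batch getting 4 or 5 indices and every index appearing exactly once. A standalone sketch (not part of the commit) checking that property:

#include <cassert>
#include <vector>

int main() {
  const int total = 81, num_batches = 20;  // example values; any positive pair works
  std::vector<int> hits(total, 0);
  for (int batch_index = 0; batch_index < num_batches; ++batch_index) {
    // Same arithmetic as ThreadPool::BatchParallelFor above.
    int start = batch_index * total / num_batches;
    int end = (batch_index + 1) * total / num_batches;
    for (int i = start; i < end; ++i) {
      hits[i]++;
    }
  }
  for (int h : hits) {
    assert(h == 1);  // every index in [0, total) is covered exactly once
  }
  return 0;
}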

onnxruntime/core/providers/cpu/object_detection/roialign.cc

Lines changed: 2 additions & 17 deletions
@@ -42,17 +42,6 @@ ADD_TYPED_ROIALIGN_OP(float);
 ADD_TYPED_ROIALIGN_OP(double);
 
 namespace {
-template <typename T>
-void TryParallelFor(concurrency::ThreadPool* tp, int32_t total, T&& fn) {
-  if (tp != nullptr)
-    tp->ParallelFor(total, fn);
-  else {
-    for (int32_t i = 0; i != total; ++i) {
-      fn(i);
-    }
-  }
-}
-
 template <typename T>
 struct PreCalc {
   int64_t pos1;
@@ -183,9 +172,7 @@ void RoiAlignForward(const TensorShape& output_shape,
   int64_t pooled_height = output_shape[2];
   int64_t pooled_width = output_shape[3];
 
-  // TODO: This should do blocks of work based on the number of threads in the threadpool with each block
-  // being n_rois / num_threads
-  std::function<void(int32_t)> work_object = [&](int32_t n) {
+  ThreadPool::TryBatchParallelFor(ttp, static_cast<int32_t>(n_rois), [&](int32_t n) {
     int64_t index_n = n * channels * pooled_width * pooled_height;
 
     const T* offset_bottom_rois = bottom_rois + n * num_roi_cols;
@@ -281,9 +268,7 @@ void RoiAlignForward(const TensorShape& output_shape,
       }  // for pw
     }  // for ph
   }  // for c
-  };  // for n
-
-  TryParallelFor(ttp, static_cast<int32_t>(n_rois), work_object);
+  });  // for n
 }
 }  // namespace

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/platform/threadpool.h"
+
+#include <core/common/make_unique.h>
+
+#include "gtest/gtest.h"
+#include <algorithm>
+#include <memory>
+#include <functional>
+#include <mutex>
+
+using namespace onnxruntime::concurrency;
+
+namespace {
+
+struct TestData {
+  explicit TestData(int num) : data(num, 0) {}
+  std::vector<int> data;
+  std::mutex mutex;
+};
+
+// These tests exercise the ThreadPool by counting the number of calls made with each index;
+// the work function should be called exactly once for each element.
+
+std::unique_ptr<TestData> CreateTestData(int num) {
+  return onnxruntime::make_unique<TestData>(num);
+}
+
+void IncrementElement(TestData& test_data, int i) {
+  std::lock_guard<std::mutex> lock(test_data.mutex);
+  test_data.data[i]++;
+}
+
+void ValidateTestData(TestData& test_data) {
+  ASSERT_TRUE(std::count_if(test_data.data.cbegin(),
+                            test_data.data.cend(),
+                            [](int i) { return i != 1; }) == 0);
+}
+
+void CreateThreadPoolAndTest(const std::string& name, int num_threads, const std::function<void(ThreadPool*)>& test_body) {
+  auto tp = onnxruntime::make_unique<ThreadPool>(name, num_threads);
+  test_body(tp.get());
+}
+
+void TestParallelFor(const std::string& name, int num_threads, int num_tasks) {
+  auto test_data = CreateTestData(num_tasks);
+  CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) {
+    tp->ParallelFor(num_tasks, [&](int i) {
+      IncrementElement(*test_data, i);
+    });
+  });
+  ValidateTestData(*test_data);
+}
+
+void TestBatchParallelFor(const std::string& name, int num_threads, int num_tasks, int num_batches) {
+  auto test_data = CreateTestData(num_tasks);
+  CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) {
+    tp->BatchParallelFor(
+        num_tasks, [&](int i) {
+          IncrementElement(*test_data, i);
+        },
+        num_batches);
+  });
+  ValidateTestData(*test_data);
+}
+
+}  // namespace
+
+TEST(ThreadPoolTest, TestParallelFor_2_Thread_NoTask) {
+  TestParallelFor("TestParallelFor_2_Thread_NoTask", 2, 0);
+}
+
+TEST(ThreadPoolTest, TestParallelFor_2_Thread_50_Task) {
+  TestParallelFor("TestParallelFor_2_Thread_50_Task", 2, 50);
+}
+
+TEST(ThreadPoolTest, TestParallelFor_1_Thread_50_Task) {
+  TestParallelFor("TestParallelFor_1_Thread_50_Task", 1, 50);
+}
+
+TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_10_Batch) {
+  TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_10_Batch", 2, 50, 10);
+}
+
+TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_0_Batch) {
+  TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_0_Batch", 2, 50, 0);
+}
+
+TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_1_Batch) {
+  TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_1_Batch", 2, 50, 1);
+}
+
+TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_50_Task_100_Batch) {
+  TestBatchParallelFor("TestBatchParallelFor_2_Thread_50_Task_100_Batch", 2, 50, 100);
+}
+
+TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_81_Task_20_Batch) {
+  TestBatchParallelFor("TestBatchParallelFor_2_Thread_81_Task_20_Batch", 2, 81, 20);
+}
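
A possible edge-case addition, not in this commit, using the same TestBatchParallelFor helper to exercise the total == 1 early-return in BatchParallelFor:

TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_1_Task_8_Batch) {
  // Hypothetical extension: a single task with more batches than tasks
  // should still invoke the work function exactly once.
  TestBatchParallelFor("TestBatchParallelFor_2_Thread_1_Task_8_Batch", 2, 1, 8);
}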
