Skip to content

Commit 66c19b4

Browse files
committed
[XPU] refactor thread_local
1 parent ff9406b commit 66c19b4

File tree

13 files changed

+483
-177
lines changed

13 files changed

+483
-177
lines changed

lite/api/cxx_api.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ void Predictor::SaveOpKernelInfo(const std::string &model_dir) {
182182

183183
#if !defined(LITE_WITH_METAL)
184184
lite::Tensor *Predictor::GetInput(size_t offset) {
185+
#ifdef LITE_WITH_XPU
186+
XPU_CALL(xpu_set_device(reinterpret_cast<lite::XPURunTimeOption *>(
187+
runtime_options_[TARGET(kXPU)].get())
188+
->xpu_dev_num));
189+
#endif
185190
CHECK(input_names_.size() > offset)
186191
<< "The network has " << input_names_.size() << " inputs"
187192
<< ", the offset should be less than this.";

lite/api/cxx_api.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,16 @@ class LITE_API Predictor {
164164
CheckInputValid();
165165

166166
#ifdef LITE_WITH_XPU
167+
if (lite::TargetWrapperXPU::xpu_runtime_ptr !=
168+
runtime_options_[TARGET(kXPU)].get()) {
169+
lite::TargetWrapperXPU::xpu_runtime_ptr =
170+
reinterpret_cast<lite::XPURunTimeOption*>(
171+
runtime_options_[TARGET(kXPU)].get());
172+
// Since the runtime context is thread_local, we must reset the XPU device
173+
// when switching between different predictors on the same thread.
174+
XPU_CALL(
175+
xpu_set_device(lite::TargetWrapperXPU::xpu_runtime_ptr->xpu_dev_num));
176+
}
167177
std::vector<std::vector<int64_t>> query_shape;
168178
for (size_t i = 0; i < input_names_.size(); i++) {
169179
query_shape.push_back(std::vector<int64_t>(GetInput(i)->dims().data()));
@@ -237,6 +247,29 @@ class LITE_API Predictor {
237247
void CheckPaddleOpVersions(
238248
const std::shared_ptr<cpp::ProgramDesc>& program_desc);
239249

250+
// Copies the target-specific runtime options carried by |config| into this
// predictor's own runtime-option table, so each predictor owns an
// independent option object (the config may be reused or destroyed later).
// Currently only the XPU target is handled; for other builds this is a no-op.
void SetRunTimeOption(const lite_api::CxxConfig& config) {
#ifdef LITE_WITH_XPU
  auto map_runtime_options = config.runtime_options();
  // Always install a fresh, predictor-owned option object for kXPU.
  std::shared_ptr<void> runtime_option =
      std::make_shared<lite::XPURunTimeOption>();
  runtime_options_.emplace(TARGET(kXPU), std::move(runtime_option));
  // If the config carries XPU options, copy their values into our object.
  // static_cast is the correct (and cheaper to audit) cast for a void*
  // that originally pointed at a lite::XPURunTimeOption.
  if (map_runtime_options[TARGET(kXPU)].get()) {
    static_cast<lite::XPURunTimeOption*>(runtime_options_[TARGET(kXPU)].get())
        ->Set(static_cast<const lite::XPURunTimeOption*>(
            map_runtime_options[TARGET(kXPU)].get()));
  }
#else
  (void)config;  // Unused when XPU support is compiled out.
#endif
}
264+
265+
#ifdef LITE_WITH_XPU
// Stores a user-provided XPU stream handle in this predictor's runtime
// options, so kernels launched by this predictor run on that stream.
void SetXPUStream(void* stream) {
  // The kXPU entry is installed by SetRunTimeOption() and is known to hold a
  // lite::XPURunTimeOption, so static_cast from void* is the right cast here.
  static_cast<lite::XPURunTimeOption*>(runtime_options_[TARGET(kXPU)].get())
      ->xpu_stream.SetXPUStream(stream);
}
#endif
272+
240273
// #ifdef LITE_WITH_TRAIN
241274
// void Run(const std::vector<framework::Tensor>& tensors) {
242275
// FeedVars(tensors);
@@ -257,6 +290,8 @@ class LITE_API Predictor {
257290
#endif
258291

259292
private:
293+
std::map<TargetType, std::shared_ptr<void>> runtime_options_;
294+
260295
std::shared_ptr<cpp::ProgramDesc> program_desc_;
261296
std::shared_ptr<Scope> scope_;
262297
Scope* exec_scope_;
@@ -323,6 +358,9 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
323358
const std::string& model_dir,
324359
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
325360
bool record_info = false) override;
361+
#ifdef LITE_WITH_XPU
362+
void SetXPUStream(void* stream);
363+
#endif
326364

327365
private:
328366
std::shared_ptr<Predictor> raw_predictor_;

lite/api/cxx_api_impl.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
4444
config_ = config;
4545
mode_ = config.power_mode();
4646
threads_ = config.threads();
47+
raw_predictor_->SetRunTimeOption(config);
4748
#ifdef LITE_USE_THREAD_POOL
4849
int thread_num = ThreadPool::Init(threads_);
4950
if (thread_num > 1) {
@@ -278,6 +279,12 @@ bool CxxPaddleApiImpl::TryShrinkMemory() {
278279
return raw_predictor_->TryShrinkMemory();
279280
}
280281

282+
#ifdef LITE_WITH_XPU
283+
void CxxPaddleApiImpl::SetXPUStream(void *stream) {
284+
raw_predictor_->SetXPUStream(stream);
285+
}
286+
#endif
287+
281288
} // namespace lite
282289

283290
namespace lite_api {

lite/api/paddle_api.cc

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#ifdef LITE_WITH_XPU
2525
#include <functional>
2626
#include <mutex> // NOLINT
27+
#include "lite/backends/xpu/runtime_option.h"
2728
#include "lite/backends/xpu/target_wrapper.h"
2829
#endif
2930

@@ -266,6 +267,11 @@ ConfigBase::ConfigBase(PowerMode mode, int threads) {
266267
mode_ = lite::DeviceInfo::Global().mode();
267268
threads_ = lite::DeviceInfo::Global().threads();
268269
#endif
270+
#ifdef LITE_WITH_XPU
271+
std::shared_ptr<void> runtime_option =
272+
std::shared_ptr<lite::XPURunTimeOption>(new lite::XPURunTimeOption);
273+
runtime_options_.emplace(TARGET(kXPU), std::move(runtime_option));
274+
#endif
269275
}
270276

271277
void ConfigBase::set_opencl_binary_path_name(const std::string &path,
@@ -478,10 +484,14 @@ void CxxConfig::set_xpu_l3_cache_method(size_t l3_size, bool locked) {
478484
CHECK(lite::TargetWrapperXPU::shared_l3_size >= l3_size)
479485
<< "Enlarge XPU Shared L3 Cache Is Not Allowed.";
480486
}
481-
lite::TargetWrapperXPU::local_l3_size = 0;
487+
reinterpret_cast<lite::XPURunTimeOption *>(
488+
runtime_options()[TARGET(kXPU)].get())
489+
->xpu_local_l3_size = 0;
482490
lite::TargetWrapperXPU::need_l3_mutex = true;
483491
} else {
484-
lite::TargetWrapperXPU::local_l3_size = l3_size;
492+
reinterpret_cast<lite::XPURunTimeOption *>(
493+
runtime_options()[TARGET(kXPU)].get())
494+
->xpu_local_l3_size = l3_size;
485495
lite::TargetWrapperXPU::need_l3_mutex = false;
486496
}
487497
#else
@@ -493,17 +503,21 @@ void CxxConfig::set_xpu_l3_cache_method(size_t l3_size, bool locked) {
493503

494504
void CxxConfig::set_xpu_l3_cache_autotune(bool autotune) {
495505
#ifdef LITE_WITH_XPU
496-
lite::TargetWrapperXPU::local_l3_autotune = autotune;
506+
reinterpret_cast<lite::XPURunTimeOption *>(
507+
runtime_options()[TARGET(kXPU)].get())
508+
->xpu_local_l3_autotune = autotune;
497509
#else
498510
LOG(WARNING) << "The invoking of the function "
499511
"'set_xpu_l3_cache_autotune' is ignored, please "
500512
"rebuild it with LITE_WITH_XPU=ON.";
501513
#endif
502514
}
503515

504-
void set_xpu_gm_workspace_method(size_t gm_size) {
516+
void CxxConfig::set_xpu_gm_workspace_method(size_t gm_size) {
505517
#ifdef LITE_WITH_XPU
506-
lite::TargetWrapperXPU::local_gm_size = gm_size;
518+
reinterpret_cast<lite::XPURunTimeOption *>(
519+
runtime_options()[TARGET(kXPU)].get())
520+
->xpu_local_gm_size = gm_size;
507521
#else
508522
LOG(WARNING) << "The invoking of the function "
509523
"'set_xpu_gm_workspace_method' is ignored, please "
@@ -513,7 +527,9 @@ void set_xpu_gm_workspace_method(size_t gm_size) {
513527

514528
void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
515529
#ifdef LITE_WITH_XPU
516-
lite::TargetWrapperXPU::SetDev(dev_no);
530+
reinterpret_cast<lite::XPURunTimeOption *>(
531+
runtime_options()[TARGET(kXPU)].get())
532+
->xpu_dev_num = dev_no;
517533
#else
518534
LOG(WARNING) << "The invoking of the function 'set_xpu_dev_per_thread' is "
519535
"ignored, please rebuild it with LITE_WITH_XPU=ON.";
@@ -522,7 +538,9 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
522538

523539
void CxxConfig::enable_xpu_multi_stream() {
524540
#ifdef LITE_WITH_XPU
525-
lite::TargetWrapperXPU::enable_xpu_multi_stream();
541+
reinterpret_cast<lite::XPURunTimeOption *>(
542+
runtime_options()[TARGET(kXPU)].get())
543+
->xpu_enable_multi_stream = true;
526544
#else
527545
LOG(WARNING)
528546
<< "The invoking of the function 'enable_xpu_stream_per_thread' is "
@@ -591,7 +609,9 @@ void CxxConfig::set_xpu_conv_autotune(bool autotune,
591609

592610
void CxxConfig::set_xpu_cluster_num(const int num) {
593611
#ifdef LITE_WITH_XPU
594-
lite::TargetWrapperXPU::cluster_num = num;
612+
reinterpret_cast<lite::XPURunTimeOption *>(
613+
runtime_options()[TARGET(kXPU)].get())
614+
->xpu_cluster_num = num;
595615
#else
596616
LOG(WARNING) << "The invoking of the function "
597617
"'set_xpu_cluster_num' is ignored, please "
@@ -601,14 +621,40 @@ void CxxConfig::set_xpu_cluster_num(const int num) {
601621

602622
void CxxConfig::set_xpu_sdnn_num(const int num) {
603623
#ifdef LITE_WITH_XPU
604-
lite::TargetWrapperXPU::sdnn_num = num;
624+
reinterpret_cast<lite::XPURunTimeOption *>(
625+
runtime_options()[TARGET(kXPU)].get())
626+
->xpu_sdnn_num = num;
605627
#else
606628
LOG(WARNING) << "The invoking of the function "
607629
"'set_xpu_sdnn_num' is ignored, please "
608630
"rebuild it with LITE_WITH_XPU=ON.";
609631
#endif
610632
}
611633

634+
// Records the directory/file path where XPU tensor dumps should be written.
// Without XPU support compiled in, the call is ignored with a warning.
void CxxConfig::set_xpu_dump_tensor_path(const std::string dump_tensor_path) {
#ifdef LITE_WITH_XPU
  auto options = runtime_options();
  auto* xpu_option = reinterpret_cast<lite::XPURunTimeOption *>(
      options[TARGET(kXPU)].get());
  xpu_option->xpu_dump_tensor_path = dump_tensor_path;
#else
  LOG(WARNING) << "The invoking of the function "
                  "'set_xpu_dump_tensor_path' is ignored, please "
                  "rebuild it with LITE_WITH_XPU=ON.";
#endif
}
645+
646+
// Records the path where XPU runtime logs should be dumped.
// Without XPU support compiled in, the call is ignored with a warning.
void CxxConfig::set_xpu_dump_log_path(const std::string dump_log_path) {
#ifdef LITE_WITH_XPU
  auto options = runtime_options();
  auto* xpu_option = reinterpret_cast<lite::XPURunTimeOption *>(
      options[TARGET(kXPU)].get());
  xpu_option->xpu_dump_log_path = dump_log_path;
#else
  LOG(WARNING) << "The invoking of the function "
                  "'set_xpu_dump_log_path' is ignored, please "
                  "rebuild it with LITE_WITH_XPU=ON.";
#endif
}
657+
612658
template <class T>
613659
void CxxConfig::set_preferred_inputs_for_warmup(const int group_idx,
614660
const int tensor_idx,

lite/api/paddle_api.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class LITE_API ConfigBase {
185185
bool metal_use_memory_reuse_{false};
186186

187187
std::vector<std::string> discarded_passes_{};
188+
std::map<TargetType, std::shared_ptr<void>> runtime_options_;
188189

189190
public:
190191
explicit ConfigBase(PowerMode mode = LITE_POWER_NO_BIND, int threads = 1);
@@ -346,6 +347,9 @@ class LITE_API ConfigBase {
346347
const std::vector<std::string> get_discarded_passes() const {
347348
return discarded_passes_;
348349
}
350+
// Returns a COPY of the per-target runtime-option table. The map itself is
// copied on every call, but the shared_ptr values still alias the same
// underlying option objects, so callers that mutate through the returned
// pointers (e.g. the set_xpu_* setters) affect this config.
std::map<TargetType, std::shared_ptr<void>> runtime_options() const {
  return runtime_options_;
}
349353
};
350354

351355
class LITE_API CxxModelBuffer {
@@ -455,6 +459,8 @@ class LITE_API CxxConfig : public ConfigBase {
455459
void set_xpu_sdnn_num(const int num);
456460
void set_xpu_local_quant(bool local_quant = false);
457461
void set_xpu_compute_precision(const std::string& precision = "int16");
462+
void set_xpu_dump_tensor_path(const std::string dump_tensor_path = "");
463+
void set_xpu_dump_log_path(const std::string dump_log_path = "");
458464

459465
// set input tensor for warmup.
460466
// It is optional. If you set prefered_inputs, model wil run immediately when

0 commit comments

Comments
 (0)