Skip to content

Commit f2dc28e

Browse files
xiuxin121newway
authored and committed
[XPU][Cherry-pick] refactor thread_local (PaddlePaddle#9817)
1 parent fe36f3c commit f2dc28e

File tree

13 files changed

+482
-177
lines changed

13 files changed

+482
-177
lines changed

lite/api/cxx_api.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ void Predictor::SaveOpKernelInfo(const std::string &model_dir) {
182182

183183
#if !defined(LITE_WITH_METAL)
184184
lite::Tensor *Predictor::GetInput(size_t offset) {
185+
#ifdef LITE_WITH_XPU
186+
XPU_CALL(xpu_set_device(reinterpret_cast<lite::XPURunTimeOption *>(
187+
target_configs_[TARGET(kXPU)].get())
188+
->xpu_dev_num));
189+
#endif
185190
CHECK(input_names_.size() > offset)
186191
<< "The network has " << input_names_.size() << " inputs"
187192
<< ", the offset should be less than this.";

lite/api/cxx_api.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,16 @@ class LITE_API Predictor {
164164
CheckInputValid();
165165

166166
#ifdef LITE_WITH_XPU
167+
if (lite::TargetWrapperXPU::xpu_runtime_ptr !=
168+
target_configs_[TARGET(kXPU)].get()) {
169+
lite::TargetWrapperXPU::xpu_runtime_ptr =
170+
reinterpret_cast<lite::XPURunTimeOption*>(
171+
target_configs_[TARGET(kXPU)].get());
172+
// because the runtime context is thread_local, we should set the device when
173+
// using different predictors in the same thread.
174+
XPU_CALL(
175+
xpu_set_device(lite::TargetWrapperXPU::xpu_runtime_ptr->xpu_dev_num));
176+
}
167177
std::vector<std::vector<int64_t>> query_shape;
168178
for (size_t i = 0; i < input_names_.size(); i++) {
169179
query_shape.push_back(std::vector<int64_t>(GetInput(i)->dims().data()));
@@ -237,6 +247,31 @@ class LITE_API Predictor {
237247
void CheckPaddleOpVersions(
238248
const std::shared_ptr<cpp::ProgramDesc>& program_desc);
239249

250+
void SetTargetConfigs(
251+
const std::map<TargetType, std::shared_ptr<void>>& target_configs) {
252+
#ifdef LITE_WITH_XPU
253+
std::shared_ptr<void> runtime_option =
254+
std::shared_ptr<lite::XPURunTimeOption>(new lite::XPURunTimeOption);
255+
target_configs_.emplace(TARGET(kXPU), std::move(runtime_option));
256+
if (target_configs.at(TARGET(kXPU)).get()) {
257+
reinterpret_cast<lite::XPURunTimeOption*>(
258+
target_configs_[TARGET(kXPU)].get())
259+
->Set(reinterpret_cast<const lite::XPURunTimeOption*>(
260+
target_configs.at(TARGET(kXPU)).get()));
261+
}
262+
#endif
263+
}
264+
265+
void SetStream(TargetType target, void* stream) {
266+
if (target == TARGET(kXPU)) {
267+
#ifdef LITE_WITH_XPU
268+
reinterpret_cast<lite::XPURunTimeOption*>(
269+
target_configs_[TARGET(kXPU)].get())
270+
->xpu_stream.SetXPUStream(stream);
271+
#endif
272+
}
273+
}
274+
240275
// #ifdef LITE_WITH_TRAIN
241276
// void Run(const std::vector<framework::Tensor>& tensors) {
242277
// FeedVars(tensors);
@@ -257,6 +292,8 @@ class LITE_API Predictor {
257292
#endif
258293

259294
private:
295+
std::map<TargetType, std::shared_ptr<void>> target_configs_;
296+
260297
std::shared_ptr<cpp::ProgramDesc> program_desc_;
261298
std::shared_ptr<Scope> scope_;
262299
Scope* exec_scope_;
@@ -324,6 +361,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
324361
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
325362
bool record_info = false) override;
326363

364+
void SetStream(TargetType target, void* stream);
365+
327366
private:
328367
std::shared_ptr<Predictor> raw_predictor_;
329368
lite_api::CxxConfig config_;

lite/api/cxx_api_impl.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
4444
config_ = config;
4545
mode_ = config.power_mode();
4646
threads_ = config.threads();
47+
raw_predictor_->SetTargetConfigs(config.target_configs());
4748
#ifdef LITE_USE_THREAD_POOL
4849
int thread_num = ThreadPool::Init(threads_);
4950
if (thread_num > 1) {
@@ -278,6 +279,10 @@ bool CxxPaddleApiImpl::TryShrinkMemory() {
278279
return raw_predictor_->TryShrinkMemory();
279280
}
280281

282+
void CxxPaddleApiImpl::SetStream(TargetType target, void *stream) {
283+
raw_predictor_->SetStream(target, stream);
284+
}
285+
281286
} // namespace lite
282287

283288
namespace lite_api {

lite/api/paddle_api.cc

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#ifdef LITE_WITH_XPU
2525
#include <functional>
2626
#include <mutex> // NOLINT
27+
#include "lite/backends/xpu/runtime_option.h"
2728
#include "lite/backends/xpu/target_wrapper.h"
2829
#endif
2930

@@ -264,6 +265,11 @@ ConfigBase::ConfigBase(PowerMode mode, int threads) {
264265
mode_ = lite::DeviceInfo::Global().mode();
265266
threads_ = lite::DeviceInfo::Global().threads();
266267
#endif
268+
#ifdef LITE_WITH_XPU
269+
std::shared_ptr<void> runtime_option =
270+
std::shared_ptr<lite::XPURunTimeOption>(new lite::XPURunTimeOption);
271+
target_configs_.emplace(TARGET(kXPU), std::move(runtime_option));
272+
#endif
267273
}
268274

269275
void ConfigBase::set_opencl_binary_path_name(const std::string &path,
@@ -483,10 +489,14 @@ void CxxConfig::set_xpu_l3_cache_method(size_t l3_size, bool locked) {
483489
CHECK(lite::TargetWrapperXPU::shared_l3_size >= l3_size)
484490
<< "Enlarge XPU Shared L3 Cache Is Not Allowed.";
485491
}
486-
lite::TargetWrapperXPU::local_l3_size = 0;
492+
reinterpret_cast<lite::XPURunTimeOption *>(
493+
target_configs()[TARGET(kXPU)].get())
494+
->xpu_local_l3_size = 0;
487495
lite::TargetWrapperXPU::need_l3_mutex = true;
488496
} else {
489-
lite::TargetWrapperXPU::local_l3_size = l3_size;
497+
reinterpret_cast<lite::XPURunTimeOption *>(
498+
target_configs()[TARGET(kXPU)].get())
499+
->xpu_local_l3_size = l3_size;
490500
lite::TargetWrapperXPU::need_l3_mutex = false;
491501
}
492502
#else
@@ -498,17 +508,21 @@ void CxxConfig::set_xpu_l3_cache_method(size_t l3_size, bool locked) {
498508

499509
void CxxConfig::set_xpu_l3_cache_autotune(bool autotune) {
500510
#ifdef LITE_WITH_XPU
501-
lite::TargetWrapperXPU::local_l3_autotune = autotune;
511+
reinterpret_cast<lite::XPURunTimeOption *>(
512+
target_configs()[TARGET(kXPU)].get())
513+
->xpu_local_l3_autotune = autotune;
502514
#else
503515
LOG(WARNING) << "The invoking of the function "
504516
"'set_xpu_l3_cache_autotune' is ignored, please "
505517
"rebuild it with LITE_WITH_XPU=ON.";
506518
#endif
507519
}
508520

509-
void set_xpu_gm_workspace_method(size_t gm_size) {
521+
void CxxConfig::set_xpu_gm_workspace_method(size_t gm_size) {
510522
#ifdef LITE_WITH_XPU
511-
lite::TargetWrapperXPU::local_gm_size = gm_size;
523+
reinterpret_cast<lite::XPURunTimeOption *>(
524+
target_configs()[TARGET(kXPU)].get())
525+
->xpu_local_gm_size = gm_size;
512526
#else
513527
LOG(WARNING) << "The invoking of the function "
514528
"'set_xpu_gm_workspace_method' is ignored, please "
@@ -518,7 +532,9 @@ void set_xpu_gm_workspace_method(size_t gm_size) {
518532

519533
void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
520534
#ifdef LITE_WITH_XPU
521-
lite::TargetWrapperXPU::SetDev(dev_no);
535+
reinterpret_cast<lite::XPURunTimeOption *>(
536+
target_configs()[TARGET(kXPU)].get())
537+
->xpu_dev_num = dev_no;
522538
#else
523539
LOG(WARNING) << "The invoking of the function 'set_xpu_dev_per_thread' is "
524540
"ignored, please rebuild it with LITE_WITH_XPU=ON.";
@@ -527,7 +543,9 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
527543

528544
void CxxConfig::enable_xpu_multi_stream() {
529545
#ifdef LITE_WITH_XPU
530-
lite::TargetWrapperXPU::enable_xpu_multi_stream();
546+
reinterpret_cast<lite::XPURunTimeOption *>(
547+
target_configs()[TARGET(kXPU)].get())
548+
->xpu_enable_multi_stream = true;
531549
#else
532550
LOG(WARNING)
533551
<< "The invoking of the function 'enable_xpu_stream_per_thread' is "
@@ -596,7 +614,9 @@ void CxxConfig::set_xpu_conv_autotune(bool autotune,
596614

597615
void CxxConfig::set_xpu_cluster_num(const int num) {
598616
#ifdef LITE_WITH_XPU
599-
lite::TargetWrapperXPU::cluster_num = num;
617+
reinterpret_cast<lite::XPURunTimeOption *>(
618+
target_configs()[TARGET(kXPU)].get())
619+
->xpu_cluster_num = num;
600620
#else
601621
LOG(WARNING) << "The invoking of the function "
602622
"'set_xpu_cluster_num' is ignored, please "
@@ -606,14 +626,40 @@ void CxxConfig::set_xpu_cluster_num(const int num) {
606626

607627
void CxxConfig::set_xpu_sdnn_num(const int num) {
608628
#ifdef LITE_WITH_XPU
609-
lite::TargetWrapperXPU::sdnn_num = num;
629+
reinterpret_cast<lite::XPURunTimeOption *>(
630+
target_configs()[TARGET(kXPU)].get())
631+
->xpu_sdnn_num = num;
610632
#else
611633
LOG(WARNING) << "The invoking of the function "
612634
"'set_xpu_sdnn_num' is ignored, please "
613635
"rebuild it with LITE_WITH_XPU=ON.";
614636
#endif
615637
}
616638

639+
void CxxConfig::set_xpu_dump_tensor_path(const std::string dump_tensor_path) {
640+
#ifdef LITE_WITH_XPU
641+
reinterpret_cast<lite::XPURunTimeOption *>(
642+
target_configs()[TARGET(kXPU)].get())
643+
->xpu_dump_tensor_path = dump_tensor_path;
644+
#else
645+
LOG(WARNING) << "The invoking of the function "
646+
"'set_xpu_dump_tensor_path' is ignored, please "
647+
"rebuild it with LITE_WITH_XPU=ON.";
648+
#endif
649+
}
650+
651+
void CxxConfig::set_xpu_dump_log_path(const std::string dump_log_path) {
652+
#ifdef LITE_WITH_XPU
653+
reinterpret_cast<lite::XPURunTimeOption *>(
654+
target_configs()[TARGET(kXPU)].get())
655+
->xpu_dump_log_path = dump_log_path;
656+
#else
657+
LOG(WARNING) << "The invoking of the function "
658+
"'set_xpu_dump_log_path' is ignored, please "
659+
"rebuild it with LITE_WITH_XPU=ON.";
660+
#endif
661+
}
662+
617663
template <class T>
618664
void CxxConfig::set_preferred_inputs_for_warmup(const int group_idx,
619665
const int tensor_idx,

lite/api/paddle_api.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class LITE_API ConfigBase {
185185
bool metal_use_memory_reuse_{false};
186186

187187
std::vector<std::string> discarded_passes_{};
188+
std::map<TargetType, std::shared_ptr<void>> target_configs_;
188189

189190
public:
190191
explicit ConfigBase(PowerMode mode = LITE_POWER_NO_BIND, int threads = 1);
@@ -350,6 +351,9 @@ class LITE_API ConfigBase {
350351
// Set external custom allocator
351352
void set_custom_allocator(TargetType target_type,
352353
CustomAllocator custom_allocator);
354+
std::map<TargetType, std::shared_ptr<void>> target_configs() const {
355+
return target_configs_;
356+
}
353357
};
354358

355359
class LITE_API CxxModelBuffer {
@@ -459,6 +463,8 @@ class LITE_API CxxConfig : public ConfigBase {
459463
void set_xpu_sdnn_num(const int num);
460464
void set_xpu_local_quant(bool local_quant = false);
461465
void set_xpu_compute_precision(const std::string& precision = "int16");
466+
void set_xpu_dump_tensor_path(const std::string dump_tensor_path = "");
467+
void set_xpu_dump_log_path(const std::string dump_log_path = "");
462468

463469
// set input tensor for warmup.
464470
// It is optional. If you set prefered_inputs, model wil run immediately when

0 commit comments

Comments
 (0)