Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lite/backends/xpu/target_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ void TargetWrapperXPU::FreeL3Cache() {
}

// xpu context
LITE_THREAD_LOCAL xdnn::Context* TargetWrapperXPU::tls_raw_ctx_{nullptr};
LITE_THREAD_LOCAL std::shared_ptr<xdnn::Context> TargetWrapperXPU::tls_raw_ctx_{
nullptr};
// multi encoder config
LITE_THREAD_LOCAL std::string
TargetWrapperXPU::multi_encoder_precision; // NOLINT
Expand Down
10 changes: 5 additions & 5 deletions lite/backends/xpu/target_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ class TargetWrapper<TARGET(kXPU)> {
static XPUScratchPadGuard MallocScratchPad(size_t size);

static xdnn::Context* GetRawContext() {
if (tls_raw_ctx_ == nullptr) {
tls_raw_ctx_ = xdnn::create_context();
CHECK(tls_raw_ctx_);
if (tls_raw_ctx_.get() == nullptr) {
tls_raw_ctx_.reset(xdnn::create_context(), xdnn::destroy_context);
CHECK(tls_raw_ctx_.get());
if (l3_planner_ == nullptr) {
l3_planner_ = new XPUL3Planner;
}
Expand Down Expand Up @@ -121,7 +121,7 @@ class TargetWrapper<TARGET(kXPU)> {
}
}
}
return tls_raw_ctx_;
return tls_raw_ctx_.get();
}
static void MallocL3Cache(
const std::vector<std::vector<int64_t>>& query_shape);
Expand Down Expand Up @@ -161,7 +161,7 @@ class TargetWrapper<TARGET(kXPU)> {
void* l3_ptr,
size_t l3_size,
const std::vector<std::vector<int64_t>>& query_shape);
static LITE_THREAD_LOCAL xdnn::Context* tls_raw_ctx_;
static LITE_THREAD_LOCAL std::shared_ptr<xdnn::Context> tls_raw_ctx_;
static LITE_THREAD_LOCAL void* local_l3_ptr_;
static void* shared_l3_ptr_;
static std::mutex mutex_l3_;
Expand Down