Skip to content

Commit fba141d

Browse files
authored
[XPU] super big ernie support (#7184)
1 parent 1526903 commit fba141d

5 files changed

Lines changed: 34 additions & 3 deletions

File tree

lite/api/paddle_api.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,16 @@ void CxxConfig::set_xpu_l3_cache_method(size_t l3_size, bool locked) {
547547
#endif
548548
}
549549

550+
void set_xpu_gm_workspace_method(size_t gm_size) {
551+
#ifdef LITE_WITH_XPU
552+
lite::TargetWrapperXPU::local_gm_size = gm_size;
553+
#else
554+
LOG(WARNING) << "The invoking of the function "
555+
"'set_xpu_gm_workspace_method' is ignored, please "
556+
"rebuild it with LITE_WITH_XPU=ON.";
557+
#endif
558+
}
559+
550560
void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
551561
#ifdef LITE_WITH_XPU
552562
lite::TargetWrapperXPU::SetDev(dev_no);

lite/api/paddle_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,8 @@ class LITE_API CxxConfig : public ConfigBase {
408408
void set_xpu_workspace_l3_size_per_thread(int l3_size = 0x4000000);
409409
void set_xpu_l3_cache_method(size_t l3_size, bool locked = false);
410410

411+
// Sets the XPU GM workspace size to gm_size. In builds with LITE_WITH_XPU the
// value is stored in lite::TargetWrapperXPU::local_gm_size; in other builds
// the call is ignored and a warning is logged.
// NOTE(review): gm_size is presumably a byte count (it is later passed to
// xpu_malloc as the allocation size) - confirm.
void set_xpu_gm_workspace_method(size_t gm_size);
412+
411413
void set_xpu_conv_autotune(bool autotune = true,
412414
const std::string& autotune_file = "");
413415

lite/backends/xpu/target_wrapper.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ LITE_THREAD_LOCAL std::string TargetWrapperXPU::conv_autotune_file;
180180
LITE_THREAD_LOCAL bool TargetWrapperXPU::need_l3_mutex{false};
181181
LITE_THREAD_LOCAL size_t TargetWrapperXPU::local_l3_size{
182182
std::numeric_limits<size_t>::max()};
183+
// Size of the local GM workspace requested from xpu_malloc when the XPU
// context is set up; a value of 0 skips that allocation. Overridden through
// CxxConfig::set_xpu_gm_workspace_method.
// NOTE(review): presumably a per-thread byte count (LITE_THREAD_LOCAL,
// default 64 * 1024 * 1024) - confirm the unit against xpu_malloc.
LITE_THREAD_LOCAL size_t TargetWrapperXPU::local_gm_size{
184+
0x4000000}; // 64 * 1024 * 1024
183185
LITE_THREAD_LOCAL void* TargetWrapperXPU::local_l3_ptr_{nullptr};
184186
void* TargetWrapperXPU::shared_l3_ptr_{nullptr};
185187
size_t TargetWrapperXPU::shared_l3_size{0};

lite/backends/xpu/target_wrapper.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,22 @@ class TargetWrapper<TARGET(kXPU)> {
100100
local_l3_size = max_l3_size;
101101
}
102102
CHECK_LE(shared_l3_size, max_l3_size);
103+
if (local_gm_size > 0) {
104+
VLOG(3) << "Try To Malloc Local GM Workspace Size is" << local_gm_size;
105+
void* local_gm_ptr = nullptr;
106+
int ret =
107+
xpu_malloc(reinterpret_cast<void**>(&local_gm_ptr), local_gm_size);
108+
if (ret != 0) {
109+
VLOG(3) << "No Enough GM Workspace For Current Predictor.";
110+
} else {
111+
ret = tls_raw_ctx_->_gm_mgr.set(local_gm_ptr, local_gm_size);
112+
if (ret != 0) {
113+
LOG(WARNING) << "XPU GM Mgr Init Fail, Please Check Configuration.";
114+
XPU_CALL(xpu_free(local_gm_ptr));
115+
local_gm_ptr = nullptr;
116+
}
117+
}
118+
}
103119
}
104120
return tls_raw_ctx_;
105121
}
@@ -131,7 +147,8 @@ class TargetWrapper<TARGET(kXPU)> {
131147
// l3 cache config
132148
static LITE_THREAD_LOCAL bool need_l3_mutex; // model level l3 size
133149
static LITE_THREAD_LOCAL size_t local_l3_size; // model level l3 size
134-
static size_t shared_l3_size; // model level l3 size
150+
static LITE_THREAD_LOCAL size_t local_gm_size;
151+
static size_t shared_l3_size; // model level l3 size
135152
static LITE_THREAD_LOCAL std::vector<XPUL3CacheBlock*>
136153
l3_block_dict; // l3 cache block used between op layers
137154

lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,7 @@ class XPUMultiEncoderFuser {
10461046
weight_qkv_trans_int8.get(),
10471047
max_f,
10481048
qkv_len);
1049-
memcpy(weight_tensor_vec[0]->mutable_data<float>(),
1049+
memcpy(weight_tensor_vec[0]->mutable_data<int8_t>(),
10501050
weight_qkv_trans_int8.get(),
10511051
qkv_len * sizeof(int8_t));
10521052
} else {
@@ -1056,7 +1056,7 @@ class XPUMultiEncoderFuser {
10561056
weight_qkv_trans_int16.get(),
10571057
max_f,
10581058
qkv_len);
1059-
memcpy(weight_tensor_vec[0]->mutable_data<float>(),
1059+
memcpy(weight_tensor_vec[0]->mutable_data<int16_t>(),
10601060
weight_qkv_trans_int16.get(),
10611061
qkv_len * sizeof(int16_t));
10621062
}

0 commit comments

Comments
 (0)