@@ -31,31 +31,16 @@ namespace xpu = baidu::xpu::api;
3131namespace phi {
3232
3333struct XPUContext ::Impl {
34- void SetL3Cache (int l3_size = 14155776 ) {
35- const int MAX_XPU_NUM = 16 ;
36- static void * l3ptrs[MAX_XPU_NUM] = {nullptr };
37-
38- if (std::getenv (" XPU_PADDLE_L3_SIZE" ) != nullptr ) {
39- l3_size = atoi (std::getenv (" XPU_PADDLE_L3_SIZE" ));
40- }
41-
42- auto selected_xpus = backends::xpu::GetXPUSelectedDevices ();
43- for (unsigned int i = 0 ; i < selected_xpus.size (); i++) {
44- if (place_.GetDeviceId () == selected_xpus[i]) {
45- if (l3ptrs[place_.GetDeviceId ()] != nullptr ) {
46- xpu_free (l3ptrs[place_.GetDeviceId ()]);
47- l3ptrs[place_.GetDeviceId ()] = nullptr ;
48- }
49- xpu_malloc (static_cast <void **>(&l3ptrs[place_.GetDeviceId ()]),
50- l3_size,
51- XPU_MEM_L3);
52- if (l3ptrs[place_.GetDeviceId ()] != nullptr ) {
53- context_->_l3_mgr .set (l3ptrs[place_.GetDeviceId ()], l3_size);
54- VLOG (3 ) << " xpu place " << static_cast <int >(place_.GetDeviceId ())
55- << " set l3 size " << l3_size;
56- }
57- break ;
58- }
34+ void SetL3Cache (int l3_size = 1024 ) {
35+ PADDLE_ENFORCE_XPU_SUCCESS (xpu_wait (context_->xpu_stream ));
36+ context_->_l3_mgr .set (nullptr , 0 , true ); // free origin l3
37+ void * l3_ptr = nullptr ;
38+ xpu_malloc (static_cast <void **>(&l3_ptr), l3_size, XPU_MEM_L3);
39+
40+ if (l3_ptr != nullptr ) {
41+ VLOG (3 ) << " xpu place " << static_cast <int >(place_.GetDeviceId ())
42+ << " context " << context_ << " set l3 size " << l3_size;
43+ context_->_l3_mgr .set (l3_ptr, l3_size, true );
5944 }
6045 }
6146
@@ -145,28 +130,26 @@ struct XPUContext::Impl {
145130 }
146131 }
147132
148- void Init () {
133+ void Init (int gm_default_size = 1024 , int l3_default_size = 1024 ) {
149134 owned_ = true ;
150135 backends::xpu::XPUDeviceGuard guard (place_.GetDeviceId ());
151136 LOG_FIRST_N (WARNING, 1 )
152137 << " Please NOTE: xpu device: " << static_cast <int >(place_.device );
138+
153139 context_ = xpu::create_context ();
154- // Setup XPU GM Buffer
155- if (std::getenv (" XPUAPI_DEFAULT_SIZE" ) != nullptr ) {
156- context_->set_option (" XPUAPI_DEFAULT_SIZE" ,
157- std::getenv (" XPUAPI_DEFAULT_SIZE" ));
158- } else {
159- // Optimization described in
160- // https://github.com/PaddlePaddle/Paddle/pull/54674
161- context_->set_option (" XPUAPI_DEFAULT_SIZE" , " 1" );
162- }
140+ context_->set_option (" XPUAPI_DEFAULT_SIZE" ,
141+ std::to_string (gm_default_size).c_str ());
142+ VLOG (3 ) << " xpu place " << static_cast <int >(place_.GetDeviceId ())
143+ << " context " << context_ << " set xpuapi_default_size "
144+ << gm_default_size;
145+
163146 if (std::getenv (" XPU_CDNN_CLUSTER_PARALLEL" ) != nullptr ) {
164147 XPUStream s;
165148 xpu_stream_create (&s);
166149 context_->set_stream (s);
167150 }
168151 xpu_version_ = backends::xpu::get_xpu_version (place_.device );
169- SetL3Cache ();
152+ SetL3Cache (l3_default_size );
170153 }
171154
172155 void SetXContext (xpu::Context* context) {
@@ -239,27 +222,61 @@ struct XPUContext::Impl {
239222 xpu::BKCLContext_t bkcl_context_{nullptr };
240223};
241224
225+ static int get_gm_size (int i) {
226+ int default_size = 1024 ;
227+ if (std::getenv (" XPUAPI_DEFAULT_SIZE" ) != nullptr ) {
228+ default_size = atoi (std::getenv (" XPUAPI_DEFAULT_SIZE" ));
229+ }
230+ std::string cur_env = std::string (" XPUAPI_DEFAULT_SIZE" ) + std::to_string (i);
231+ if (std::getenv (cur_env.c_str ()) != nullptr ) {
232+ default_size = atoi (std::getenv (cur_env.c_str ()));
233+ }
234+ return default_size;
235+ }
236+
237+ static int get_l3_size (int i) {
238+ int default_size = 1024 ;
239+ if (std::getenv (" XPU_PADDLE_L3_SIZE" ) != nullptr ) {
240+ default_size = atoi (std::getenv (" XPU_PADDLE_L3_SIZE" ));
241+ }
242+ std::string cur_env = std::string (" XPU_PADDLE_L3_SIZE" ) + std::to_string (i);
243+ if (std::getenv (cur_env.c_str ()) != nullptr ) {
244+ default_size = atoi (std::getenv (cur_env.c_str ()));
245+ }
246+ return default_size;
247+ }
248+
242249XPUContext::XPUContext () : DeviceContext() {
243250 if (std::getenv (" XPU_CDNN_CLUSTER_PARALLEL" ) != nullptr ) {
244- for (int i = 0 ; i < 4 ; i++) {
251+ int default_num_stream = 4 ;
252+ if (std::getenv (" XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER" ) != nullptr ) {
253+ default_num_stream =
254+ atoi (std::getenv (" XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER" ));
255+ }
256+ for (int i = 0 ; i < default_num_stream; i++) {
245257 impls_.push_back (std::make_unique<Impl>());
246- impls_[i]->Init ();
258+ impls_[i]->Init (get_gm_size (i), get_l3_size (i) );
247259 }
248260 } else {
249261 impls_.push_back (std::make_unique<Impl>());
250- impls_[0 ]->Init ();
262+ impls_[0 ]->Init (get_gm_size ( 0 ), get_l3_size ( 0 ) );
251263 }
252264}
253265
254266XPUContext::XPUContext (const XPUPlace& place) : DeviceContext() {
255267 if (std::getenv (" XPU_CDNN_CLUSTER_PARALLEL" ) != nullptr ) {
256- for (int i = 0 ; i < 4 ; i++) {
268+ int default_num_stream = 4 ;
269+ if (std::getenv (" XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER" ) != nullptr ) {
270+ default_num_stream =
271+ atoi (std::getenv (" XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER" ));
272+ }
273+ for (int i = 0 ; i < default_num_stream; i++) {
257274 impls_.push_back (std::make_unique<Impl>(place));
258- impls_[i]->Init ();
275+ impls_[i]->Init (get_gm_size (i), get_l3_size (i) );
259276 }
260277 } else {
261278 impls_.push_back (std::make_unique<Impl>(place));
262- impls_[0 ]->Init ();
279+ impls_[0 ]->Init (get_gm_size ( 0 ), get_l3_size ( 0 ) );
263280 }
264281}
265282
@@ -303,11 +320,13 @@ void XPUContext::Wait() const {
303320 }
304321}
305322
306- void XPUContext::SetXContext (xpu::Context* context) {
307- impls_[0 ]->SetXContext (context);
323+ void XPUContext::SetXContext (xpu::Context* context, int i ) {
324+ impls_[i ]->SetXContext (context);
308325}
309326
310- void XPUContext::SetL3Cache (int l3_size) { impls_[0 ]->SetL3Cache (l3_size); }
327+ void XPUContext::SetL3Cache (int l3_size, int i) {
328+ impls_[i]->SetL3Cache (l3_size);
329+ }
311330
312331void XPUContext::SetBkclContext (xpu::BKCLContext_t context) {
313332 impls_[0 ]->SetBkclContext (context);
0 commit comments