Skip to content

Commit 319d536

Browse files
committed
refine
1 parent 2682045 commit 319d536

5 files changed

Lines changed: 27 additions & 157 deletions

File tree

paddle/phi/backends/onednn/onednn_reuse.h

Lines changed: 16 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -318,28 +318,16 @@ class OneDNNHandlerT {
318318
typename std::enable_if<std::is_same<typename std::decay<First>::type,
319319
dnnl::primitive_attr>::value>::type
320320
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
321-
try {
322-
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
323-
engine_, std::forward<Args>(args)..., first);
324-
} catch (std::exception& ex) {
325-
LOG(WARNING) << ex.what();
326-
PADDLE_THROW(phi::errors::Unavailable("wanghuan7"));
327-
std::rethrow_exception(std::current_exception());
328-
}
321+
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
322+
engine_, std::forward<Args>(args)..., first);
329323
}
330324

331325
template <class First, class... Args>
332326
typename std::enable_if<!std::is_same<typename std::decay<First>::type,
333327
dnnl::primitive_attr>::value>::type
334328
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
335-
try {
336-
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
337-
engine_, std::forward<First>(first), std::forward<Args>(args)...);
338-
} catch (std::exception& ex) {
339-
LOG(WARNING) << ex.what();
340-
PADDLE_THROW(phi::errors::Unavailable("wanghuan8"));
341-
std::rethrow_exception(std::current_exception());
342-
}
329+
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
330+
engine_, std::forward<First>(first), std::forward<Args>(args)...);
343331
}
344332

345333
template <typename... Args>
@@ -354,14 +342,8 @@ class OneDNNHandlerT {
354342
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
355343
dev_ctx_.GetBlob(key_pd));
356344
if (bwd_pd_ == nullptr) {
357-
try {
358-
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
359-
engine_, std::forward<Args>(args)..., *fwd_pd_);
360-
} catch (std::exception& ex) {
361-
LOG(WARNING) << ex.what();
362-
PADDLE_THROW(phi::errors::Unavailable("wanghuan1"));
363-
std::rethrow_exception(std::current_exception());
364-
}
345+
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
346+
engine_, std::forward<Args>(args)..., *fwd_pd_);
365347
dev_ctx_.SetBlob(key_pd, bwd_pd_);
366348
}
367349
}
@@ -379,14 +361,8 @@ class OneDNNHandlerT {
379361
std::static_pointer_cast<typename TBackward_params::primitive_desc>(
380362
dev_ctx_.GetBlob(key_pd));
381363
if (bwd_w_pd_ == nullptr) {
382-
try {
383-
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
384-
engine_, std::forward<Args>(args)..., *fwd_pd_);
385-
} catch (std::exception& ex) {
386-
LOG(WARNING) << ex.what();
387-
PADDLE_THROW(phi::errors::Unavailable("wanghuan2"));
388-
std::rethrow_exception(std::current_exception());
389-
}
364+
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
365+
engine_, std::forward<Args>(args)..., *fwd_pd_);
390366
dev_ctx_.SetBlob(key_pd, bwd_w_pd_);
391367
}
392368
}
@@ -645,28 +621,16 @@ class OneDNNHandlerNoCachingT {
645621
typename std::enable_if<std::is_same<typename std::decay<First>::type,
646622
dnnl::primitive_attr>::value>::type
647623
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
648-
try {
649-
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
650-
engine_, std::forward<Args>(args)..., first);
651-
} catch (std::exception& ex) {
652-
LOG(WARNING) << ex.what();
653-
PADDLE_THROW(phi::errors::Unavailable("wanghuan3"));
654-
std::rethrow_exception(std::current_exception());
655-
}
624+
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
625+
engine_, std::forward<Args>(args)..., first);
656626
}
657627

658628
template <class First, class... Args>
659629
typename std::enable_if<!std::is_same<typename std::decay<First>::type,
660630
dnnl::primitive_attr>::value>::type
661631
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
662-
try {
663-
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
664-
engine_, std::forward<First>(first), std::forward<Args>(args)...);
665-
} catch (std::exception& ex) {
666-
LOG(WARNING) << ex.what();
667-
PADDLE_THROW(phi::errors::Unavailable("wanghuan4"));
668-
std::rethrow_exception(std::current_exception());
669-
}
632+
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
633+
engine_, std::forward<First>(first), std::forward<Args>(args)...);
670634
}
671635

672636
template <typename... Args>
@@ -676,14 +640,8 @@ class OneDNNHandlerNoCachingT {
676640
PADDLE_ENFORCE_NOT_NULL(
677641
fwd_pd_,
678642
errors::Unavailable("Get oneDNN Forward primitive %s failed."));
679-
try {
680-
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
681-
engine_, std::forward<Args>(args)..., *fwd_pd_);
682-
} catch (std::exception& ex) {
683-
LOG(WARNING) << ex.what();
684-
PADDLE_THROW(phi::errors::Unavailable("wanghuan5"));
685-
std::rethrow_exception(std::current_exception());
686-
}
643+
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
644+
engine_, std::forward<Args>(args)..., *fwd_pd_);
687645
}
688646

689647
template <typename... Args>
@@ -695,14 +653,8 @@ class OneDNNHandlerNoCachingT {
695653
errors::Unavailable("Get oneDNN Forward primitive %s failed."));
696654
auto bwd_desc =
697655
typename TBackward_params::desc(std::forward<Args>(args)...);
698-
try {
699-
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
700-
bwd_desc, engine_, *fwd_pd_);
701-
} catch (std::exception& ex) {
702-
LOG(WARNING) << ex.what();
703-
PADDLE_THROW(phi::errors::Unavailable("wanghuan6"));
704-
std::rethrow_exception(std::current_exception());
705-
}
656+
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
657+
bwd_desc, engine_, *fwd_pd_);
706658
}
707659

708660
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(

paddle/phi/core/dense_tensor_impl.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,8 @@ const dnnl::memory::desc& DenseTensor::mem_desc() const {
382382
std::unique_ptr<StorageProperties>* storage_properties_ptr =
383383
const_cast<std::unique_ptr<StorageProperties>*>(&storage_properties_);
384384
*storage_properties_ptr = std::make_unique<OneDNNStorageProperties>();
385+
static_cast<OneDNNStorageProperties*>(storage_properties_ptr->get())
386+
->mem_desc = dnnl::memory::desc();
385387
}
386388
return this->storage_properties<OneDNNStorageProperties>().mem_desc;
387389
}

0 commit comments

Comments
 (0)