Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions paddle/phi/core/dense_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ DenseTensor::DenseTensor(const DenseTensor& other) {
storage_properties_ =
std::move(CopyStorageProperties(other.storage_properties_));
inplace_version_counter_ = other.inplace_version_counter_;

#ifdef PADDLE_WITH_DNNL
mem_desc_ = other.mem_desc_;
#endif
}

DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
Expand All @@ -74,9 +70,6 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
storage_properties_ =
std::move(CopyStorageProperties(other.storage_properties_));
inplace_version_counter_ = other.inplace_version_counter_;
#ifdef PADDLE_WITH_DNNL
mem_desc_ = other.mem_desc_;
#endif
return *this;
}

Expand All @@ -85,9 +78,6 @@ DenseTensor& DenseTensor::operator=(DenseTensor&& other) noexcept {
std::swap(holder_, other.holder_);
storage_properties_ = std::move(other.storage_properties_);
std::swap(inplace_version_counter_, other.inplace_version_counter_);
#ifdef PADDLE_WITH_DNNL
mem_desc_ = other.mem_desc_;
#endif
return *this;
}

Expand Down
18 changes: 0 additions & 18 deletions paddle/phi/core/dense_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@ limitations under the License. */
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/utils/test_macros.h"

/* @jim19930609: Move to MKLDNN_Tensor in the future
*/
#ifdef PADDLE_WITH_DNNL
#include "dnnl.hpp" // NOLINT
#endif

namespace phi {

class DenseTensorUtils;
Expand Down Expand Up @@ -290,18 +284,6 @@ class TEST_API DenseTensor : public TensorBase,
std::shared_ptr<InplaceVersion> inplace_version_counter_ =
std::make_shared<InplaceVersion>();

/* @jim19930609: This is a hack
In general, it is badly designed to fuse MKLDNN-specific objects into a
generic Tensor.
We temporarily leave them here to unblock Tensor Unification progress.
In the final state, we should come up with a MKLDNN_Tensor and move the
following codes there.
*/
#ifdef PADDLE_WITH_DNNL
/// \brief memory descriptor of tensor which have layout set as kMKLDNN
dnnl::memory::desc mem_desc_;
#endif

#ifndef PADDLE_WITH_CUSTOM_KERNEL
#include "paddle/phi/core/dense_tensor.inl"
#endif
Expand Down
5 changes: 1 addition & 4 deletions paddle/phi/core/dense_tensor.inl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,7 @@ following codes there.
public:
const dnnl::memory::desc& mem_desc() const;

inline void set_mem_desc(const dnnl::memory::desc& mem_desc) {
mem_desc_ = mem_desc;
meta_.layout = DataLayout::ONEDNN;
}
void set_mem_desc(const dnnl::memory::desc& mem_desc);

#endif

Expand Down
27 changes: 23 additions & 4 deletions paddle/phi/core/dense_tensor_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,29 @@ std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
}

#ifdef PADDLE_WITH_DNNL
const dnnl::memory::desc& DenseTensor::mem_desc() const { return mem_desc_; }
const dnnl::memory::desc& DenseTensor::mem_desc() const {
if (storage_properties_ == nullptr) {
std::unique_ptr<StorageProperties>* storage_properties_ptr =
const_cast<std::unique_ptr<StorageProperties>*>(&storage_properties_);
*storage_properties_ptr = std::make_unique<OneDNNStorageProperties>();
}
return this->storage_properties<OneDNNStorageProperties>().mem_desc;
}

void DenseTensor::set_mem_desc(const dnnl::memory::desc& mem_desc) {
if (storage_properties_ == nullptr) {
storage_properties_ = std::make_unique<OneDNNStorageProperties>();
}
if (OneDNNStorageProperties::classof(storage_properties_.get())) {
dynamic_cast<OneDNNStorageProperties*>(storage_properties_.get())
->mem_desc = mem_desc;
meta_.layout = DataLayout::ONEDNN;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The actual type of storage_properties is inconsistent with the type "
"of the template parameter passed in."));
}
}
#endif

// NOTE: For historical reasons, this interface has a special behavior,
Expand All @@ -394,9 +416,6 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
meta_.strides = src.meta_.strides;
storage_properties_ =
std::move(CopyStorageProperties(src.storage_properties_));
#ifdef PADDLE_WITH_DNNL
mem_desc_ = src.mem_desc_;
#endif
return *this;
}

Expand Down