diff --git a/paddle/phi/common/place.cc b/paddle/phi/common/place.cc index 008f45aa935544..4057373e73bd8f 100644 --- a/paddle/phi/common/place.cc +++ b/paddle/phi/common/place.cc @@ -37,6 +37,8 @@ const char *AllocationTypeStr(AllocationType type) { return "xpu"; case AllocationType::IPU: return "ipu"; + case AllocationType::CUSTOM: + return "custom_device"; default: PD_THROW("Invalid phi device type."); return {}; diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index 1181a812669762..205bd0614bf3e3 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -304,7 +304,25 @@ template const XPUStorageProperties& DenseTensor::storage_properties() const; #endif bool DenseTensor::storage_properties_initialized() const { - return storage_properties_ != nullptr; + if (storage_properties_ == nullptr) { + return false; + } else if (NPUStorageProperties::classof(storage_properties_.get())) { + return place().GetType() == AllocationType::CUSTOM; +#ifdef PADDLE_WITH_XPU + } else if (XPUStorageProperties::classof(storage_properties_.get())) { + return place().GetType() == AllocationType::XPU; +#endif +#ifdef PADDLE_WITH_DNNL + } else if (OneDNNStorageProperties::classof(storage_properties_.get())) { + return place().GetType() == AllocationType::CPU; +#endif + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("The type of storage_properties [%s] is " + "inconsistent with tensor place [%s]", + storage_properties_->type_info().name(), + AllocationTypeStr(place().GetType()))); + } } void DenseTensor::set_storage_properties(