Skip to content

Commit 80d7a86

Browse files
authored
[NPU] fix storage_properties type mismatch with OneDNN and NPU (#60566)
1 parent da71db0 commit 80d7a86

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

paddle/phi/common/place.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ const char *AllocationTypeStr(AllocationType type) {
3737
return "xpu";
3838
case AllocationType::IPU:
3939
return "ipu";
40+
case AllocationType::CUSTOM:
41+
return "custom_device";
4042
default:
4143
PD_THROW("Invalid phi device type.");
4244
return {};

paddle/phi/core/dense_tensor.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,25 @@ template const XPUStorageProperties& DenseTensor::storage_properties() const;
304304
#endif
305305

306306
bool DenseTensor::storage_properties_initialized() const {
307-
return storage_properties_ != nullptr;
307+
if (storage_properties_ == nullptr) {
308+
return false;
309+
} else if (NPUStorageProperties::classof(storage_properties_.get())) {
310+
return place().GetType() == AllocationType::CUSTOM;
311+
#ifdef PADDLE_WITH_XPU
312+
} else if (XPUStorageProperties::classof(storage_properties_.get())) {
313+
return place().GetType() == AllocationType::XPU;
314+
#endif
315+
#ifdef PADDLE_WITH_DNNL
316+
} else if (OneDNNStorageProperties::classof(storage_properties_.get())) {
317+
return place().GetType() == AllocationType::CPU;
318+
#endif
319+
} else {
320+
PADDLE_THROW(
321+
phi::errors::InvalidArgument("The type of storage_properties [%s] is "
322+
"inconsistent with tensor place [%s]",
323+
storage_properties_->type_info().name(),
324+
AllocationTypeStr(place().GetType())));
325+
}
308326
}
309327

310328
void DenseTensor::set_storage_properties(

0 commit comments

Comments
 (0)