diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index 35f165a59e10e4..444f88c77bfbc7 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -360,6 +360,14 @@ void BindPlace(pybind11::module &m) { // NOLINT [](const phi::CustomPlace &self) { return self.GetDeviceType(); }) .def("__repr__", string::to_string) .def("__str__", string::to_string); +#if defined(PADDLE_WITH_CUSTOM_DEVICE) + m.def("is_float16_supported", [](const phi::CustomPlace &place) -> bool { + return phi::DeviceManager::IsFloat16Supported(place); + }); + m.def("is_bfloat16_supported", [](const phi::CustomPlace &place) -> bool { + return phi::DeviceManager::IsBFloat16Supported(place); + }); +#endif py::class_ cudaplace(m, "CUDAPlace", R"DOC( CUDAPlace is a descriptor of a device. diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc index 0e0d9b9e3aa83a..12ef9f995e7f29 100644 --- a/paddle/phi/backends/custom/custom_device.cc +++ b/paddle/phi/backends/custom/custom_device.cc @@ -628,6 +628,26 @@ class CustomDevice : public DeviceInterface { return grid_dim_size; } + bool IsFloat16Supported(size_t dev_id) { + const auto device = &devices_pool[dev_id]; + bool supported = false; + if (pimpl_->is_float16_supported) { + pimpl_->is_float16_supported(device, &supported); + } + VLOG(10) << Type() << " is float16 supported: " << supported; + return supported; + } + + bool IsBFloat16Supported(size_t dev_id) { + const auto device = &devices_pool[dev_id]; + bool supported = false; + if (pimpl_->is_bfloat16_supported) { + pimpl_->is_bfloat16_supported(device, &supported); + } + VLOG(10) << Type() << " is bfloat16 supported: " << false; + return supported; + } + void* InitEigenDevice(const Place& place, phi::stream::stream_t stream, phi::Allocator* allocator) override { diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc index 56051aa2fbe53a..1405cb82087ad1 100644 --- a/paddle/phi/backends/device_base.cc +++ b/paddle/phi/backends/device_base.cc @@ -37,7 +37,7 @@ size_t DeviceInterface::GetComputeCapability(size_t dev_id) { } DeviceProp& DeviceInterface::GetDeviceProperties(size_t dev_id) { - DeviceProp prop; + static DeviceProp prop; VLOG(10) << Type() << " get device properties " << 0; return prop; } @@ -73,6 +73,16 @@ std::array DeviceInterface::GetMaxGridDimSize(size_t dev_id) { return {0, 0, 0}; } +bool DeviceInterface::IsFloat16Supported(size_t dev_id) { + VLOG(10) << Type() << " is float16 supported: " << false; + return false; +} + +bool DeviceInterface::IsBFloat16Supported(size_t dev_id) { + VLOG(10) << Type() << " is bfloat16 supported: " << false; + return false; +} + void* DeviceInterface::InitEigenDevice(const Place& place, phi::stream::stream_t stream, phi::Allocator* allocator) { diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h index 563e3b99e51174..2a198797aa6c8b 100644 --- a/paddle/phi/backends/device_base.h +++ b/paddle/phi/backends/device_base.h @@ -79,6 +79,10 @@ class DeviceInterface { // Driver / Runtime virtual std::array GetMaxGridDimSize(size_t dev_id); + virtual bool IsFloat16Supported(size_t dev_id); + + virtual bool IsBFloat16Supported(size_t dev_id); + virtual void* InitEigenDevice(const Place& place, phi::stream::stream_t stream, phi::Allocator* allocator); diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h index 87f2af1361b38b..ddd1120723661c 100644 --- a/paddle/phi/backends/device_ext.h +++ b/paddle/phi/backends/device_ext.h @@ -590,6 +590,20 @@ struct C_DeviceInterface { C_Status (*get_max_grid_dim_size)(const C_Device device, std::array* grid_dim_size); + /** + * @brief Is float16 supported + * + * @param[C_Device, bool*] device, supported + */ + C_Status (*is_float16_supported)(const C_Device device, bool* supported); + + /** + * @brief Is bfloat16 supported + * + * @param[C_Device, bool*] device, supported + */ + C_Status (*is_bfloat16_supported)(const C_Device device, bool* supported); + /** * @brief init eigen device * diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index 6162931fe0982f..220b472c9af3d4 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -523,6 +523,20 @@ std::array DeviceManager::GetMaxGridDimSize( return dev_impl->GetMaxGridDimSize(device_id); } +bool DeviceManager::IsFloat16Supported(const Place& place) { + auto device_type = place.GetDeviceType(); + auto device_id = place.GetDeviceId(); + auto dev_impl = GetDeviceInterfaceWithType(device_type); + return dev_impl->IsFloat16Supported(device_id); +} + +bool DeviceManager::IsBFloat16Supported(const Place& place) { + auto device_type = place.GetDeviceType(); + auto device_id = place.GetDeviceId(); + auto dev_impl = GetDeviceInterfaceWithType(device_type); + return dev_impl->IsBFloat16Supported(device_id); +} + void* DeviceManager::InitEigenDevice(const Place& place, phi::stream::stream_t stream, phi::Allocator* allocator) { diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h index eeeb93087e5d0b..0e418e4b635754 100644 --- a/paddle/phi/backends/device_manager.h +++ b/paddle/phi/backends/device_manager.h @@ -186,6 +186,10 @@ class DeviceManager { static std::array GetMaxGridDimSize(const Place& place); + static bool IsFloat16Supported(const Place& place); + + static bool IsBFloat16Supported(const Place& place); + static void* InitEigenDevice(const Place& place, phi::stream::stream_t stream, phi::Allocator* allocator);