Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,11 @@ All parameter, weight, gradient are variables in Paddle.
#if (defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif
py::class_<platform::CUDAPlace>(m, "CUDAPlace")
py::class_<platform::CUDAPlace>(m, "CUDAPlace", R"DOC(
CUDAPlace is a descriptor of a device. It represents a GPU, and each CUDAPlace
has a dev_id to indicate the number of cards represented by the current CUDAPlace.
The memory of CUDAPlace with different dev_id is not accessible.
)DOC")
.def("__init__",
[](platform::CUDAPlace &self, int dev_id) {
#ifdef PADDLE_WITH_CUDA
Expand All @@ -783,7 +787,10 @@ All parameter, weight, gradient are variables in Paddle.
&IsSamePlace<platform::CUDAPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CUDAPlace &>);

py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
CPUPlace is a descriptor of a device. It represents a CPU, and the memory
CPUPlace can be accessed by CPU.
)DOC")
.def(py::init<>())
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::Place>)
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::CUDAPlace>)
Expand All @@ -792,7 +799,10 @@ All parameter, weight, gradient are variables in Paddle.
&IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CPUPlace &>);

py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace")
py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace", R"DOC(
CUDAPinnedPlace is a descriptor of a device. The memory of CUDAPinnedPlace
can be accessed by GPU and CPU.
)DOC")
.def("__init__",
[](platform::CUDAPinnedPlace &self) {
#ifndef PADDLE_WITH_CUDA
Expand Down