Skip to content

Commit c369baf

Browse files
Modify the macros for the CUDA stream and event classes
1 parent 0ff122a commit c369baf

File tree

4 files changed

+70
-36
lines changed

4 files changed

+70
-36
lines changed

paddle/fluid/platform/event.h

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -117,8 +117,8 @@ class MemEvent {
117117
std::string annotation_;
118118
};
119119

120-
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
121120
class CudaEvent {
121+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
122122
public:
123123
CudaEvent() { cudaEventCreateWithFlags(&event_, flags_); }
124124

@@ -151,17 +151,23 @@ class CudaEvent {
151151
private:
152152
unsigned int flags_ = cudaEventDefault;
153153
gpuEvent_t event_;
154+
#endif
154155
};
155156

156157
static unsigned int get_cuda_flags(bool enable_timing, bool blocking,
157158
bool interprocess) {
159+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
158160
unsigned int flags =
159161
(blocking ? cudaEventBlockingSync : cudaEventDefault) |
160162
(enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
161163
(interprocess ? cudaEventInterprocess : cudaEventDefault);
162164
return flags;
163-
}
165+
#else
166+
PADDLE_THROW(platform::errors::Unavailable(
167+
"Paddle is not compiled with CUDA. Cannot get the cuda event flags."));
168+
return 0;
164169
#endif
170+
}
165171

166172
} // namespace platform
167173
} // namespace paddle

paddle/fluid/platform/stream/cuda_stream.cc

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -96,8 +96,8 @@ void CUDAStream::Wait() const {
9696
PADDLE_ENFORCE_CUDA_SUCCESS(e_sync);
9797
}
9898

99-
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
10099
CUDAStream* get_current_stream(int deviceId) {
100+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
101101
if (deviceId == -1) {
102102
deviceId = platform::GetCurrentDeviceId();
103103
}
@@ -111,8 +111,12 @@ CUDAStream* get_current_stream(int deviceId) {
111111
->Stream()
112112
.get();
113113
return stream;
114-
}
114+
#else
115+
PADDLE_THROW(platform::errors::Unavailable(
116+
"Paddle is not compiled with CUDA. Cannot visit cuda current stream."));
117+
return nullptr;
115118
#endif
119+
}
116120

117121
} // namespace stream
118122
} // namespace platform

paddle/fluid/platform/stream/cuda_stream.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,9 @@ enum class Priority : uint8_t {
3333
kHigh = 0x1,
3434
kNormal = 0x2,
3535
};
36-
36+
#endif
3737
class CUDAStream final {
38+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
3839
public:
3940
CUDAStream() = default;
4041
explicit CUDAStream(const Place& place,
@@ -119,12 +120,11 @@ class CUDAStream final {
119120
#endif
120121
Priority priority_{Priority::kNormal};
121122
std::unique_ptr<StreamCallbackManager<gpuStream_t>> callback_manager_;
122-
123+
#endif
123124
DISABLE_COPY_AND_ASSIGN(CUDAStream);
124125
};
125126

126127
CUDAStream* get_current_stream(int deviceId);
127-
#endif
128128

129129
} // namespace stream
130130
} // namespace platform

paddle/fluid/pybind/imperative.cc

Lines changed: 53 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -610,13 +610,19 @@ void BindImperative(py::module *m_ptr) {
610610
[](const std::shared_ptr<imperative::Tracer> &tracer) {
611611
imperative::SetCurrentTracer(tracer);
612612
});
613-
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
614613
m.def("_get_current_stream",
615614
[](int deviceId) {
615+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
616616
return paddle::platform::stream::get_current_stream(deviceId);
617+
#else
618+
PADDLE_THROW(platform::errors::Unavailable(
619+
"Paddle is not compiled with CUDA. Cannot visit cuda current "
620+
"stream."));
621+
#endif
617622
},
618623
py::return_value_policy::reference);
619624
m.def("_device_synchronize", [](int device_id) {
625+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
620626
if (device_id == -1) {
621627
device_id = paddle::platform::GetCurrentDeviceId();
622628
}
@@ -625,8 +631,11 @@ void BindImperative(py::module *m_ptr) {
625631
paddle::platform::SetDeviceId(device_id);
626632
PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize());
627633
paddle::platform::SetDeviceId(curr_device_id);
628-
});
634+
#else
635+
PADDLE_THROW(platform::errors::Unavailable(
636+
"Paddle is not compiled with CUDA. Cannot visit device synchronize."));
629637
#endif
638+
});
630639

631640
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
632641
m, "VarBase", R"DOC()DOC")
@@ -1692,24 +1701,28 @@ void BindImperative(py::module *m_ptr) {
16921701
return imperative::PyLayerApply(place, cls, args, kwargs);
16931702
});
16941703

1695-
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
16961704
py::class_<paddle::platform::stream::CUDAStream>(m, "CUDAStream")
1697-
.def("__init__",
1698-
[](paddle::platform::stream::CUDAStream &self,
1699-
platform::CUDAPlace &device, int priority) {
1700-
if (priority != 1 && priority != 2) {
1701-
PADDLE_THROW(platform::errors::InvalidArgument(
1702-
"Priority should be 1(high) or 2(normal) "));
1703-
}
1704-
auto prio = paddle::platform::stream::Priority(priority);
1705+
.def("__init__", [](paddle::platform::stream::CUDAStream &self,
1706+
platform::CUDAPlace &device, int priority) {
1707+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1708+
if (priority != 1 && priority != 2) {
1709+
PADDLE_THROW(platform::errors::InvalidArgument(
1710+
"Priority should be 1(high) or 2(normal) "));
1711+
}
1712+
auto prio = paddle::platform::stream::Priority(priority);
17051713

1706-
new (&self) paddle::platform::stream::CUDAStream(device, prio);
1707-
})
1708-
.def("wait_event",
1709-
[](paddle::platform::stream::CUDAStream &self,
1710-
paddle::platform::CudaEvent &event) {
1711-
self.WaitEvent(event.GetRawCudaEvent());
1712-
})
1714+
new (&self) paddle::platform::stream::CUDAStream(device, prio);
1715+
#else
1716+
PADDLE_THROW(platform::errors::Unavailable(
1717+
"Class CUDAStream can only be initialized on the GPU platform."));
1718+
#endif
1719+
});
1720+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1721+
m.def("wait_event",
1722+
[](paddle::platform::stream::CUDAStream &self,
1723+
paddle::platform::CudaEvent &event) {
1724+
self.WaitEvent(event.GetRawCudaEvent());
1725+
})
17131726
.def("wait_stream",
17141727
[](paddle::platform::stream::CUDAStream &self,
17151728
paddle::platform::stream::CUDAStream &stream) {
@@ -1736,23 +1749,34 @@ void BindImperative(py::module *m_ptr) {
17361749
return event;
17371750

17381751
});
1752+
#endif
17391753

17401754
py::class_<paddle::platform::CudaEvent>(m, "CUDAEvent")
17411755
.def("__init__",
1742-
[](paddle::platform::CudaEvent &self, bool enable_timing = false,
1743-
bool blocking = false, bool interprocess = false) {
1756+
[](paddle::platform::CudaEvent &self, bool enable_timing,
1757+
bool blocking, bool interprocess) {
1758+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
17441759
unsigned int flags = platform::get_cuda_flags(
17451760
enable_timing, blocking, interprocess);
17461761
new (&self) paddle::platform::CudaEvent(flags);
1747-
})
1748-
.def("record",
1749-
[](paddle::platform::CudaEvent &self,
1750-
paddle::platform::stream::CUDAStream *stream) {
1751-
if (stream == nullptr) {
1752-
stream = paddle::platform::stream::get_current_stream(-1);
1753-
}
1754-
self.Record(*stream);
1755-
})
1762+
#else
1763+
PADDLE_THROW(platform::errors::Unavailable(
1764+
"Class CUDAEvent can only be initialized on the GPU "
1765+
"platform."));
1766+
1767+
#endif
1768+
},
1769+
py::arg("enable_timing") = false, py::arg("blocking") = false,
1770+
py::arg("interprocess") = false);
1771+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1772+
m.def("record",
1773+
[](paddle::platform::CudaEvent &self,
1774+
paddle::platform::stream::CUDAStream *stream) {
1775+
if (stream == nullptr) {
1776+
stream = paddle::platform::stream::get_current_stream(-1);
1777+
}
1778+
self.Record(*stream);
1779+
})
17561780
.def("query",
17571781
[](paddle::platform::CudaEvent &self) { return self.Query(); })
17581782
.def("synchronize",

0 commit comments

Comments (0)