Skip to content

Commit 20cfa8b

Browse files
authored
Abstract GenerateDeviceEventFlag to shield platforms (#35219)
* Abstract GenerateDeviceEventFlag to shield platforms
* Remove get_cuda_flags
1 parent 31cd106 commit 20cfa8b

File tree

5 files changed

+33
-27
lines changed

5 files changed

+33
-27
lines changed

paddle/fluid/framework/new_executor/interpretercore.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ void AssociateInputWithEvents(
7777
for (auto var_id : new_event_var_id) {
7878
if (var_id2event->count(var_id) == 0) {
7979
auto device_event = std::make_shared<platform::DeviceEvent>(
80-
place, platform::get_cuda_flags(false, false, false));
80+
place, platform::GenerateDeviceEventFlag());
8181
var_id2event->emplace(var_id, std::move(device_event));
8282
}
8383
// Add events for next_instr.inputs

paddle/fluid/platform/device_event_base.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/fluid/platform/device_event_base.h"
1616
#include "paddle/fluid/platform/device_event_cpu.h"
17+
#include "paddle/fluid/platform/event.h"
1718

1819
namespace paddle {
1920
namespace platform {
@@ -25,6 +26,31 @@ EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes];
2526
EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes];
2627
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
2728

/*
 * Generate the flag value used to create a device event, hiding the
 * platform-specific (CUDA / HIP) constants from callers.
 *
 * enable_timing: when false, the "disable timing" flag is set.
 * blocking:      when true, the "blocking sync" flag is set.
 * interprocess:  when true, the "interprocess" flag is set.
 *
 * Returns the bitwise OR of the selected CUDA (or HIP) event flags, or 0
 * when Paddle is compiled with neither CUDA nor HIP.
 * NOTE: Support CPU/CUDA/ROCM currently.
 */
unsigned int GenerateDeviceEventFlag(bool enable_timing, bool blocking,
                                     bool interprocess) {
  // The three branches are mutually exclusive, so exactly one return
  // statement is compiled in and no code is left unreachable.
#if defined(PADDLE_WITH_CUDA)
  const unsigned int flags =
      (blocking ? cudaEventBlockingSync : cudaEventDefault) |
      (enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
      (interprocess ? cudaEventInterprocess : cudaEventDefault);
  return flags;
#elif defined(PADDLE_WITH_HIP)
  const unsigned int flags =
      (blocking ? hipEventBlockingSync : hipEventDefault) |
      (enable_timing ? hipEventDefault : hipEventDisableTiming) |
      (interprocess ? hipEventInterprocess : hipEventDefault);
  return flags;
#else
  // CPU-only build: the options have no effect; silence unused-parameter
  // warnings explicitly.
  (void)enable_timing;
  (void)blocking;
  (void)interprocess;
  return 0;
#endif
}
53+
2854
void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place,
2955
unsigned int flag) {
3056
event->InitEvent(std::make_shared<CPUDeviceEventWrapper>(place, flag));

paddle/fluid/platform/device_event_base.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ inline int DeviceTypeToId(const DeviceType& device_type) {
3939
return static_cast<int>(device_type);
4040
}
4141

// Builds the flag value handed to an event constructor (e.g. the second
// argument of platform::DeviceEvent). Every option defaults to false, so a
// call with no arguments is valid.
42+
unsigned int GenerateDeviceEventFlag(bool enable_timing = false,
43+
bool blocking = false,
44+
bool interprocess = false);
45+
4246
enum EventStatus {
4347
INITIALIZED = 0,
4448
SCHEDULED = 1,

paddle/fluid/platform/event.h

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -195,30 +195,5 @@ class CudaEvent {
195195
#endif
196196
};
197197

// Removed by this commit. Combined three boolean options into a HIP or CUDA
// event flag bitmask; in builds with neither PADDLE_WITH_CUDA nor
// PADDLE_WITH_HIP it raised an Unavailable error via PADDLE_THROW.
198-
static unsigned int get_cuda_flags(bool enable_timing, bool blocking,
199-
bool interprocess) {
200-
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
201-
202-
#ifdef PADDLE_WITH_HIP
203-
unsigned int flags =
204-
(blocking ? hipEventBlockingSync : hipEventDefault) |
205-
(enable_timing ? hipEventDefault : hipEventDisableTiming) |
206-
(interprocess ? hipEventInterprocess : hipEventDefault);
207-
return flags;
208-
#else
209-
unsigned int flags =
210-
(blocking ? cudaEventBlockingSync : cudaEventDefault) |
211-
(enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
212-
(interprocess ? cudaEventInterprocess : cudaEventDefault);
213-
return flags;
214-
#endif
215-
216-
#else
217-
PADDLE_THROW(platform::errors::Unavailable(
218-
"Paddle is not compiled with CUDA. Cannot get the cuda event flags."));
219-
return 0;
220-
#endif
221-
}
222-
223198
} // namespace platform
224199
} // namespace paddle

paddle/fluid/pybind/cuda_streams_py.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <string>
1616
#include <vector>
1717

18+
#include "paddle/fluid/platform/device_event_base.h"
1819
#include "paddle/fluid/platform/event.h"
1920
#include "paddle/fluid/platform/stream/cuda_stream.h"
2021
#include "paddle/fluid/pybind/cuda_streams_py.h"
@@ -331,7 +332,7 @@ void BindCudaStream(py::module *m_ptr) {
331332
[](paddle::platform::CudaEvent &self, bool enable_timing,
332333
bool blocking, bool interprocess) {
333334
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
334-
unsigned int flags = platform::get_cuda_flags(
335+
unsigned int flags = platform::GenerateDeviceEventFlag(
335336
enable_timing, blocking, interprocess);
336337
new (&self) paddle::platform::CudaEvent(flags);
337338
#else

0 commit comments

Comments
 (0)