Skip to content

Commit adf4977

Browse files
committed
add fallback to provider_normal
Signed-off-by: Mateusz P. Nowak <[email protected]>
1 parent 48e9075 commit adf4977

File tree

3 files changed

+48
-16
lines changed

3 files changed

+48
-16
lines changed

unified-runtime/source/adapters/level_zero/v2/context.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,20 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
8686
auto device = platform->getDeviceById(deviceId);
8787

8888
// TODO: just use per-context id?
89-
return std::make_unique<v2::provider_counter>(
90-
platform, context, v2::QUEUE_IMMEDIATE, device, flags);
89+
return v2::createProvider(platform, context, v2::QUEUE_IMMEDIATE,
90+
device, flags);
91+
}),
92+
eventPoolCacheRegular(
93+
this, phDevices[0]->Platform->getNumDevices(),
94+
[context = this, platform = phDevices[0]->Platform](
95+
DeviceId deviceId,
96+
v2::event_flags_t flags) -> std::unique_ptr<v2::event_provider> {
97+
auto device = platform->getDeviceById(deviceId);
98+
99+
// TODO: just use per-context id?
100+
return v2::createProvider(platform, context, v2::QUEUE_REGULAR,
101+
device, flags);
91102
}),
92-
eventPoolCacheRegular(this, phDevices[0]->Platform->getNumDevices(),
93-
[context = this, platform = phDevices[0]->Platform](
94-
DeviceId deviceId, v2::event_flags_t flags)
95-
-> std::unique_ptr<v2::event_provider> {
96-
std::ignore = deviceId;
97-
std::ignore = platform;
98-
99-
// TODO: just use per-context id?
100-
return std::make_unique<v2::provider_normal>(
101-
context, v2::QUEUE_REGULAR, flags);
102-
}),
103103
nativeEventsPool(this, std::make_unique<v2::provider_normal>(
104104
this, v2::QUEUE_IMMEDIATE,
105105
v2::EVENT_FLAGS_PROFILING_ENABLED)),

unified-runtime/source/adapters/level_zero/v2/event_provider_counter.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "context.hpp"
1414
#include "event_provider.hpp"
1515
#include "event_provider_counter.hpp"
16+
#include "event_provider_normal.hpp"
1617
#include "loader/ze_loader.h"
1718

1819
#include "../device.hpp"
@@ -28,9 +29,15 @@ provider_counter::provider_counter(ur_platform_handle_t platform,
2829
: queueType(queueType), flags(flags) {
2930
assert(flags & EVENT_FLAGS_COUNTER);
3031

31-
ZE2UR_CALL_THROWS(zeDriverGetExtensionFunctionAddress,
32-
(platform->ZeDriver, "zexCounterBasedEventCreate2",
33-
(void **)&this->eventCreateFunc));
32+
// Try to get the counter-based event extension function
33+
auto result =
34+
ZE_CALL_NOCHECK(zeDriverGetExtensionFunctionAddress,
35+
(platform->ZeDriver, "zexCounterBasedEventCreate2",
36+
(void **)&this->eventCreateFunc));
37+
if (result != ZE_RESULT_SUCCESS) {
38+
throw ur_result_t(ze2urResult(result));
39+
}
40+
3441
ZE2UR_CALL_THROWS(zelLoaderTranslateHandle,
3542
(ZEL_HANDLE_CONTEXT, context->getZeHandle(),
3643
(void **)&translatedContext));
@@ -88,4 +95,21 @@ raii::cache_borrowed_event provider_counter::allocate() {
8895

8996
event_flags_t provider_counter::eventFlags() const { return flags; }
9097

98+
std::unique_ptr<event_provider> createProvider(ur_platform_handle_t platform,
99+
ur_context_handle_t context,
100+
queue_type queueType,
101+
ur_device_handle_t device,
102+
event_flags_t flags) {
103+
// Try to create a counter-based event provider first
104+
try {
105+
return std::make_unique<provider_counter>(platform, context, queueType,
106+
device, flags);
107+
} catch (...) {
108+
// If counter-based events are not supported, fall back to normal events
109+
// Remove the counter flag as the normal provider doesn't support it
110+
event_flags_t normalFlags = flags & ~EVENT_FLAGS_COUNTER;
111+
return std::make_unique<provider_normal>(context, queueType, normalFlags);
112+
}
113+
}
114+
91115
} // namespace v2

unified-runtime/source/adapters/level_zero/v2/event_provider_counter.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,12 @@ class provider_counter : public event_provider {
5454
std::vector<raii::ze_event_handle_t> freelist;
5555
};
5656

57+
// Factory function that creates a counter-based provider with fallback to
58+
// normal provider
59+
std::unique_ptr<event_provider> createProvider(ur_platform_handle_t platform,
60+
ur_context_handle_t context,
61+
queue_type queueType,
62+
ur_device_handle_t device,
63+
event_flags_t flags);
64+
5765
} // namespace v2

0 commit comments

Comments
 (0)