DLPack: sync CUDA stream when exporting to TensorFlow #377
Conversation
@njroussel After running the relevant unit tests in a loop on my machine for a while, this seems to resolve the synchronization issue. I wouldn't want to introduce unnecessary synchronization points, though; do you also think this is the right solution, given that TF uses separate non-default streams?
@merlinND, I'm wondering whether this change might negatively affect the conversion of other DLPack tensors. Synchronization is always bad and best avoided. For example, should we perhaps only interpret
Hi Wenzel, I agree that we should definitely avoid adding unnecessary synchronizations. Checking the documentation again, I see that:

So essentially, when passing

- In DrJit, is the main default stream always used, even when launching kernels from non-main threads?
- In DrJit, what would be a reliable way to check whether the current thread is the "main thread" (i.e. the thread for which the DrJit stream is the default stream)?

If we make the changes above, then I'm not sure which value of

Maybe
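For reference, the stream-argument convention being debated here comes from DLPack's Python-level protocol (the array API standard's `__dlpack__(stream=...)` on CUDA). Below is a minimal pure-Python sketch of how a producer might interpret those values, based on my reading of that standard; the function name and return shape are illustrative, not DrJit's actual API.

```python
# Hypothetical sketch of how a CUDA producer might interpret the
# `stream` argument of `__dlpack__`, following my reading of the
# Python array API standard (not DrJit's actual implementation).

LEGACY_DEFAULT_STREAM = 1      # the global NULL stream
PER_THREAD_DEFAULT_STREAM = 2  # the caller's per-thread default stream

def resolve_consumer_stream(stream):
    """Map a consumer-provided `stream` value to (needs_sync, target)."""
    if stream == -1:
        # Consumer signals that no synchronization is needed.
        return (False, None)
    if stream is None or stream == LEGACY_DEFAULT_STREAM:
        return (True, "legacy-default")
    if stream == PER_THREAD_DEFAULT_STREAM:
        return (True, "per-thread-default")
    if stream == 0:
        # Ambiguous between legacy and per-thread default: disallowed.
        raise ValueError("stream=0 is disallowed for CUDA producers")
    # Any other integer is a raw CUstream handle (a pointer value).
    return (True, ("custream", int(stream)))
```

Note the special value `2` (per-thread default stream), which is exactly the case the patch later in this thread handles via `cuStreamWaitEvent_ptsz`.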
Hi @wjakob, could you please check the questions above? I'll update the PR based on your answers, and it should fix the flaky TF unit tests.
(force-pushed c821068 to dcb5217)
Hi Merlin, I just had a moment to take a look at this.

That's right. I looked again at what Dr.Jit does, since this changed some time ago. Mitsuba creates a custom stream per device, which is set up with flag
```diff
diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h
index 720ca16..200d3cc 100644
--- a/include/drjit-core/jit.h
+++ b/include/drjit-core/jit.h
@@ -202,8 +202,11 @@ extern JIT_EXPORT void *jit_cuda_lookup(const char *name);
 * \brief Add CUDA event synchronization between thread state's and external
 * CUDA stream.
 *
- * An event will be recorded into the thread's states stream and the external stream
- * will wait on the event before performing any subsequent work.
+ * An event will be recorded into the thread's states stream and the external
+ * stream will wait on the event before performing any subsequent work. The
+ * special value stream==2 denotes the caller's per-thread default stream.
+ * There is no need to ever synchronize with the global NULL stream, since
+ * Dr.Jit implicitly synchronizes with respect to it.
 *
 * \param stream The CUstream handle of the external stream
 */
diff --git a/src/cuda_api.cpp b/src/cuda_api.cpp
index 293f4de..b3d0f70 100644
--- a/src/cuda_api.cpp
+++ b/src/cuda_api.cpp
@@ -123,6 +123,7 @@ bool jitc_cuda_api_init() {
     LOAD(cuStreamDestroy, "v2");
     LOAD(cuStreamSynchronize);
     LOAD(cuStreamWaitEvent);
+    LOAD(cuStreamWaitEvent_ptsz);
     LOAD(cuPointerGetAttribute);
     LOAD(cuArrayCreate, "v2");
     LOAD(cuArray3DCreate, "v2");
@@ -174,7 +175,7 @@ void jitc_cuda_api_shutdown() {
     Z(cuModuleGetFunction); Z(cuModuleLoadData); Z(cuModuleLoadDataEx); Z(cuModuleUnload);
     Z(cuOccupancyMaxPotentialBlockSize); Z(cuCtxPushCurrent);
     Z(cuCtxPopCurrent); Z(cuStreamCreate); Z(cuStreamDestroy);
-    Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuPointerGetAttribute);
+    Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuStreamWaitEvent_ptsz); Z(cuPointerGetAttribute);
     Z(cuArrayCreate); Z(cuArray3DCreate); Z(cuArray3DGetDescriptor);
     Z(cuArrayDestroy); Z(cuTexObjectCreate); Z(cuTexObjectGetResourceDesc);
     Z(cuTexObjectDestroy); Z(cuMemcpy2DAsync); Z(cuMemcpy3DAsync);
diff --git a/src/cuda_api.h b/src/cuda_api.h
index a2a04c0..142d9e5 100644
--- a/src/cuda_api.h
+++ b/src/cuda_api.h
@@ -260,6 +260,7 @@ DR_CUDA_SYM(CUresult (*cuStreamCreate)(CUstream *, unsigned int));
 DR_CUDA_SYM(CUresult (*cuStreamDestroy)(CUstream));
 DR_CUDA_SYM(CUresult (*cuStreamSynchronize)(CUstream));
 DR_CUDA_SYM(CUresult (*cuStreamWaitEvent)(CUstream, CUevent, unsigned int));
+DR_CUDA_SYM(CUresult (*cuStreamWaitEvent_ptsz)(CUstream, CUevent, unsigned int));
 DR_CUDA_SYM(CUresult (*cuMemAllocAsync)(CUdeviceptr *, size_t, CUstream));
 DR_CUDA_SYM(CUresult (*cuMemFreeAsync)(CUdeviceptr, CUstream));
diff --git a/src/cuda_core.cpp b/src/cuda_core.cpp
index 1e54389..cbcbdda 100644
--- a/src/cuda_core.cpp
+++ b/src/cuda_core.cpp
@@ -109,8 +109,12 @@ std::pair<CUmodule, bool> jitc_cuda_compile(const char *buf, bool release_state_
 void jitc_cuda_sync_stream(uintptr_t stream) {
     ThreadState* ts = thread_state(JitBackend::CUDA);
     CUevent sync_event = ts->sync_stream_event;
-    cuda_check(cuEventRecord(sync_event, (CUstream)ts->stream));
-    cuda_check(cuStreamWaitEvent((CUstream)stream, sync_event, CU_EVENT_DEFAULT));
+    scoped_set_context guard(ts->context);
+    cuda_check(cuEventRecord(sync_event, (CUstream) ts->stream));
+    if (stream != 2)
+        cuda_check(cuStreamWaitEvent((CUstream)stream, sync_event, CU_EVENT_DEFAULT));
+    else
+        cuda_check(cuStreamWaitEvent_ptsz(nullptr, sync_event, CU_EVENT_DEFAULT));
}
```
(force-pushed 61c30e2 to 6eaa0c3)
Thank you @wjakob for looking into it and providing the patch! I've opened mitsuba-renderer/drjit-core#139 with your patch. @dvicini, do the changes in this PR make sense to you? In particular, calling
@merlinND Is that so? I thought we now insert an event and wait for it asynchronously. That is assuming that TF uses the special
Sorry, it was not very clear, because two things were included in this PR:

Please let me know if I missed something. As you said,
Ok, that makes sense. Let's wait for @dvicini's take on the Google™ viewpoint before merging :-)
Ping @dvicini, do you have an opinion on the above regarding TF interop?
(force-pushed bf48ccb to e3f5762)
This PR is still waiting for feedback from @dvicini.
I completely missed those pings, sorry! I don't have a ground-truth answer, unfortunately. I experimented with TensorFlow some time ago and had a version of a custom TF op (before the PR to Dr.Jit itself was made). At the time, I ran into a number of issues with TF's own threading system (entirely separate from CUDA streams) and in the end never had a good solution. I recall that I ran into issues with specific combinations of using

Doing the conservative sync call makes sense to me, and I am not surprised that TF does not expose those streams directly. We mostly use JAX now for our projects that use Mitsuba/Dr.Jit.
(force-pushed fc79b70 to bdc17f6)
This PR attempts to fix synchronization issues that come up in the unit tests of the new TF interop feature (#301): #301 (comment)
Since TensorFlow uses non-default CUDA streams for compute and data movement, I believe that we need to synchronize the stream used by DrJit before exporting a tensor to TF.
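Putting the pieces together, the producer-side export path discussed in this PR might be outlined as follows. This is a hypothetical sketch: `eval_pending` and `sync_to_stream` are placeholder callbacks standing in for kernel submission and the `jit_cuda_sync_stream()` handshake from the patch above; they are not real DrJit entry points spelled this way.

```python
# Hypothetical outline of the producer-side DLPack export path:
# flush pending kernels, then insert an event-based dependency on the
# consumer's stream instead of a blocking device synchronization.
# `eval_pending` and `sync_to_stream` are placeholder callbacks,
# not real DrJit APIs.

def export_dlpack(tensor, stream, *, eval_pending, sync_to_stream):
    eval_pending()              # ensure all pending kernels are enqueued
    if stream != -1:            # -1: consumer needs no synchronization
        sync_to_stream(stream)  # record event + make `stream` wait on it
    return tensor               # hand the (unchanged) tensor to the consumer
```

If the consumer requested, say, the per-thread default stream (`stream == 2`), the patch in this thread routes that case through `cuStreamWaitEvent_ptsz`; any other value would be treated as a raw `CUstream` handle.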