
Extends drjit.wrap to TensorFlow #301

Merged
merlinND merged 3 commits into mitsuba-renderer:master from jhoydis:interop-tf
Mar 14, 2025

Conversation

@jhoydis (Contributor) commented Oct 17, 2024

The differentiable bridge between Dr.Jit and other array programming frameworks (drjit.wrap) is currently limited to PyTorch and JAX.

This PR extends the support to TensorFlow in both directions, i.e., 'tf->drjit' and 'drjit->tf'.

The PR introduces a new wrapper for 'tf->drjit' in drjit.wrap() and adds new cases for 'drjit->tf' in drjit.WrapADop in forward and backward mode.

For 'drjit->tf', everything works as for 'drjit->torch', with one limitation related to tf.int32 input types (for which an issue was filed). As with JAX, the TF code can be compiled using the @tf.function and @tf.function(jit_compile=True) decorators.

For 'tf->drjit', forward-mode AD is not fully supported due to limitations of TF functions with custom gradients. Such functions cannot have non-tensor input structures and we discovered and reported a bug for functions with keyword arguments.

The existing unit tests were extended to TensorFlow (in eager, graph, and XLA modes).
The function skip_on in conftest.py was extended to allow for a custom message that can differ from the information provided by the raised exception.

The docstring of drjit.wrap was modified to account for the new extensions and related limitations described above.

@wjakob (Member) commented Oct 17, 2024

Dear @jhoydis (cc @merlinND) -- this looks amazing! I would love to review this ASAP, but I am about to leave for a two-week vacation. I will take a thorough look when I am back. Regarding the issues reported to TensorFlow, I am wondering if it makes sense for some of them to be resolved before merging? (For example, the restriction to int32 sounds quite severe).

@merlinND (Member) commented Oct 17, 2024

Hi @wjakob,

I will make a first review so that hopefully it's almost ready when you're back 👍

Regarding the int32 limitation in TF, it's actually unrelated to @dr.wrap() and it has been there forever AFAICT.
Basically, int32-typed TF tensors seem to always live on the CPU, even if their device field says it's on the GPU. So when importing the tensor via DLPack, TF says it's a CPU tensor and DrJit (correctly) creates an LLVM array.

import drjit as dr
import tensorflow as tf

with tf.device('gpu'):
    x_tf = tf.constant([0, 1, 2], tf.int32)

x_dr = dr.detail.import_tensor(x_tf, ad=False)
print("TF device:", x_tf.device)
print("DrJit type:", type(x_dr))

# TF device: /job:localhost/replica:0/task:0/device:GPU:0
# DrJit type: <class 'drjit.llvm.TensorXi'>

(reproducer from @jhoydis).

So I would say that we don't need to wait on a bugfix for this one, since it doesn't prevent AD interop at all and the current TF behavior is even part of a TF unit test.
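
The behavior described above can be illustrated with a small stand-in sketch: a DLPack importer has no choice but to trust the device type reported by the producing framework. The device type codes follow the DLPack specification; the mapping function itself is an illustrative assumption, not Dr.Jit's actual implementation.

```python
# DLPack device type codes (from the DLPack specification)
K_DL_CPU = 1
K_DL_CUDA = 2


def select_backend(dl_device_type: int) -> str:
    """Hypothetical sketch: pick a backend name from a DLPack device type."""
    if dl_device_type == K_DL_CPU:
        return "llvm"   # CPU tensors map to the LLVM backend
    if dl_device_type == K_DL_CUDA:
        return "cuda"
    raise ValueError(f"unsupported DLPack device type: {dl_device_type}")


# TF exports int32 tensors as CPU-resident via DLPack, even when
# tf.Tensor.device claims GPU -- so an importer following the reported
# device (correctly) picks the LLVM backend.
```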

@merlinND (Member) left a review comment:

This is going to be a super useful feature, thanks @jhoydis!

These are mostly small comments, I tried to provide the suggested code / changes directly when possible to save time.

drjit/interop.py Outdated
if h.requires_grad:
    dr.enable_grad(r)
if source == 'tf' and enable_grad:
    dr.enable_grad(r)

@merlinND (Member):

In the PyTorch case above, we check whether the tensor has grads enabled before enabling them on the DrJit side.
If there's really no way to do an equivalent check, I think it would be good to document for which cases grads will get enabled unconditionally (maybe in the @dr.wrap docstring which already describes the other special cases).

@jhoydis (Contributor Author):

I could not find a way to figure out programmatically if a TF tensor is watched by a GradientTape (unless it is a Variable). I slightly improved this case as follows:

if source == 'tf' and enable_grad:
    # There is no TF equivalent to h.requires_grad.
    # We hence enable gradients for all trainable variables
    # and tensors of floating dtype.
    if tf_var_check(h):
        if h.trainable:
            dr.enable_grad(r)
        elif h.dtype.is_floating:
            dr.enable_grad(r)

@merlinND (Member):

I think that's correct, as far as I know there is no better way to get this information from TF.

drjit/interop.py Outdated

# Evaluate the function using another array programming framework
self.out = func(*self.args, **self.kwargs)
tensor = flatten(self.args)[1]

@merlinND (Member):

Could this fail if there are no array-typed args?

@jhoydis (Contributor Author):

I have rewritten this part. There is now a new helper function find_first_tf_tensor that finds the first TF tensor instance in a PyTree and returns it. TF will then use this tensor's device for all computations.
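
A helper of this kind could look roughly as follows. This is a generic sketch: the name `find_first_matching` and the `predicate` parameter are illustrative, and the depth-first traversal order is an assumption; the PR's actual `find_first_tf_tensor` would use a real TF type check such as `isinstance(h, tf.Tensor)` as the predicate.

```python
def find_first_matching(tree, predicate):
    """Depth-first search of a PyTree (nested lists/tuples/dicts) for the
    first leaf satisfying `predicate`; returns None if there is none."""
    if predicate(tree):
        return tree
    if isinstance(tree, (list, tuple)):
        children = tree
    elif isinstance(tree, dict):
        children = tree.values()
    else:
        return None  # non-matching leaf
    for child in children:
        found = find_first_matching(child, predicate)
        if found is not None:
            return found
    return None


# Hypothetical usage matching the PR's intent: locate the first TF tensor
# in the argument PyTree, then run all TF computations on its device.
# first = find_first_matching(args, lambda h: isinstance(h, tf.Tensor))
# device = first.device if first is not None else None
```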

@jhoydis (Contributor Author):

I have rebased this PR to the master branch and also tested with TF2.18 locally.
What is missing?

@merlinND (Member) commented:

The CI failure with GCC 9 is a bit mysterious. It seems to segfault after running test_wrap.py, but TensorFlow is not installed in that config (there's only PyTorch).
I tried reproducing locally, but everything is green on my machine.

@dvicini (Member) commented Oct 21, 2024

I am not sure about the CI failure, but another high-level comment: Is it correct to assume that the wrapping of Dr.Jit into a TF computation will not support tf.function? If so, it would be good to mention this somewhere.

For what it's worth, I never had much luck trying to get tf.function to properly wrap around Dr.Jit and ran into some nasty race conditions / deadlocks (in TF itself) when trying to come up with some combination of tf.custom_gradient and tf.py_function to make this work.

@jhoydis (Contributor Author) commented Oct 21, 2024

> I am not sure about the CI failure, but another high-level comment: Is it correct to assume that the wrapping of Dr.Jit into a TF computation will not support tf.function? If so, it would be good to mention this somewhere.
>
> For what it's worth, I never had much luck trying to get tf.function to properly wrap around Dr.Jit and ran into some nasty race conditions / deadlocks (in TF itself) when trying to come up with some combination of tf.custom_gradient and tf.py_function to make this work.

I have never tried wrapping Dr.Jit into TF via tf.py_function. But I would also assume that it does not work properly. My assumption was that a user would wrap the TF code parts separately in a tf.function without enclosing the Dr.Jit part. I agree that we should add a comment about this issue. However, you probably have the same problem with other non-TF code that you want to wrap.

@merlinND (Member) commented Jan 3, 2025

Regarding the CI:

  1. It hasn't run on Windows and macOS, maybe due to the new approval policy?
  2. IIUC, TensorFlow is installed only for the Linux (amd64, Ubuntu 24.04, GCC 13, LLVM 18) build. It would be nice if it was installed on the macOS and Windows variants as well.
  3. The CI fails for Ubuntu 20.04 + GCC 9, even though TensorFlow is not installed. I couldn't reproduce it locally :(

@njroussel (Member) commented:

Just to follow-up on the CI:

  • I'm not sure how this approval mechanism works, I'll look into it.
  • Installing PyTorch, TensorFlow, and JAX in the same environment is surprisingly (and sadly) difficult or impossible, depending on the Python/CUDA version dependency matrix. That's why TensorFlow and JAX were only installed in one of the three Linux environments. Without GPU support that shouldn't be a problem anymore, so I think we can install them on macOS without any issues. For Windows, however, I think we'd be stuck with a CPU version there too.

@wjakob (Member) commented Jan 28, 2025

Dear @jhoydis,

let me apologize for taking a really long time to get to this. I would like to review and swiftly merge this PR.
I installed TensorFlow on both the macOS and Windows CI machines.

Could I ask you to force-push this PR? That should register the PR with the build system. (I will still need to approve it, but right now I can't even do that since the changes are on your branch.)

Thanks,
Wenzel

@jhoydis (Contributor Author) commented Jan 28, 2025

I have force-pushed the last commit without any changes. It seems to have triggered the build system. Let me know if anything else is required from my side.

@wjakob (Member) commented Jan 28, 2025

I approved those builds, future ones should run directly (assuming I set this up correctly).

It looks like there are some failures -- you should be able to see the details by clicking the "guest login" button.

In any case, here is the macOS failure:

15:55:52   ==================================== ERRORS ====================================
15:55:52   __________________ ERROR collecting build/tests/test_wrap.py ___________________
15:55:52   tests/test_wrap.py:47: in <module>
15:55:52       import tensorflow as tf
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/__init__.py:37: in <module>
15:55:52       from tensorflow.python.tools import module_util as _module_util
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/__init__.py:42: in <module>
15:55:52       from tensorflow.python import data
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/data/__init__.py:21: in <module>
15:55:52       from tensorflow.python.data import experimental
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/data/experimental/__init__.py:95: in <module>
15:55:52       from tensorflow.python.data.experimental import service
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/data/experimental/service/__init__.py:387: in <module>
15:55:52       from tensorflow.python.data.experimental.ops.data_service_ops import distribute
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/data/experimental/ops/data_service_ops.py:22: in <module>
15:55:52       from tensorflow.python.data.experimental.ops import compression_ops
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/data/experimental/ops/compression_ops.py:16: in <module>
15:55:52       from tensorflow.python.data.util import structure
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/data/util/structure.py:22: in <module>
15:55:52       from tensorflow.python.data.util import nest
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/data/util/nest.py:36: in <module>
15:55:52       from tensorflow.python.framework import sparse_tensor as _sparse_tensor
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/framework/sparse_tensor.py:24: in <module>
15:55:52       from tensorflow.python.framework import constant_op
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/framework/constant_op.py:25: in <module>
15:55:52       from tensorflow.python.eager import execute
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/eager/execute.py:23: in <module>
15:55:52       from tensorflow.python.framework import dtypes
15:55:52   /Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/framework/dtypes.py:29: in <module>
15:55:52       _np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
15:55:52   E   TypeError: Unable to convert function return value to a Python type! The signature was
15:55:52   E     () -> handle
15:55:52   ------------------------------- Captured stderr --------------------------------
15:55:52   RuntimeError: module compiled against API version 0xf but this version of numpy is 0xe
15:55:52   RuntimeError: module compiled against API version 0xf but this version of numpy is 0xe
15:55:52   ImportError: numpy.core._multiarray_umath failed to import
15:55:52   ImportError: numpy.core.umath failed to import
15:55:52   =============================== warnings summary ===============================

The version of TensorFlow installed on the macOS build machine is: tensorflow-macos version 2.9.2.

On Windows (with the latest PyPI version of tensorflow), it just segfaults directly, which is kind of concerning.

15:59:44   tests\test_conversion.py::test05_roundtrip_dynamic_tf[drjit.llvm.TensorXu] PASSED [ 19%]
15:59:44   Windows fatal exception: access violation
15:59:44   
15:59:44   Current thread 0x00002618 (most recent call first):
15:59:44     Garbage-collecting
15:59:44     File "C:\Program Files\Python311\Lib\ast.py", line 50 in parse
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\_code\source.py", line 185 in getstatementrange_ast
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\_code\code.py", line 263 in getsource
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\_code\code.py", line 722 in _getentrysource
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\_code\code.py", line 814 in repr_traceback_entry
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\_code\code.py", line 871 in repr_traceback
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\_code\code.py", line 944 in repr_excinfo
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\_code\code.py", line 669 in getrepr
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\nodes.py", line 484 in _repr_failure_py
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\python.py", line 1823 in repr_failure
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\reports.py", line 349 in from_item_and_call
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\runner.py", line 366 in pytest_runtest_makereport
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\pluggy\_callers.py", line 103 in _multicall
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\pluggy\_manager.py", line 120 in _hookexec
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\pluggy\_hooks.py", line 513 in __call__
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\runner.py", line 222 in call_and_report
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\runner.py", line 131 in runtestprotocol
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\_pytest\runner.py", line 112 in pytest_runtest_protocol
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\pluggy\_callers.py", line 103 in _multicall
15:59:44     File "C:\Program Files\Python311\Lib\site-packages\pluggy\_manager.py", line 120 in _hookexec

@wjakob (Member) commented Jan 28, 2025

The macOS tensorflow install was non-functional. I fixed it and will restart the build.

@wjakob (Member) commented Jan 28, 2025

On the amd64 / GCC 9 build, the test suite runs to completion but then segfaults during interpreter shutdown when TensorFlow was used. There are several warning messages of the type "SystemError: Py_AddPendingCall: cannot add pending calls (Python shutting down)". It sounds to me like some cleanup of tensors is delayed so much that it happens while Python is shutting down.

I restarted the macOS build.

@wjakob (Member) commented Jan 28, 2025

The macOS CI machine can now properly import tensorflow. It segfaults while running the TensorFlow tests:

 build.tests.test_wrap.test01_simple_bwd(drjit_llvm_ad_Float64-config2-False)
16:17:26   PASSED [ 94%]
16:17:26   tests/test_wrap.py::test01_simple_bwd[drjit.llvm.ad.Float64-config3-True]
16:17:26   build.tests.test_wrap.test01_simple_bwd(drjit_llvm_ad_Float64-config3-True)
16:17:26   Fatal Python error: Aborted
16:17:26   
16:17:26   Current thread 0x0000000104e13d40 (most recent call first):
16:17:26     File "/Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/eager/execute.py", line 54 in quick_execute
16:17:26     File "/Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/eager/function.py", line 497 in call
16:17:26     File "/Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/eager/function.py", line 1868 in _call_flat
16:17:26     File "/Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/eager/def_function.py", line 986 in _call
16:17:26     File "/Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/eager/def_function.py", line 915 in __call__
16:17:26     File "/Users/ci/Library/Python/3.9/lib/python/site-packages/tensorflow/python/util/traceback_utils.py", line 150 in error_handler
16:17:26     File "/Users/ci/TeamCity/work/4f8e3e85cc244faa/build/drjit/interop.py", line 402 in eval
16:17:26     File "/Users/ci/TeamCity/work/4f8e3e85cc244faa/build/drjit/interop.py", line 743 in <lambda>
16:17:26     File "/Users/ci/TeamCity/work/4f8e3e85cc244faa/build/tests/test_wrap.py", line 119 in test01_simple_bwd
16:17:26     File "/Users/ci/TeamCity/work/4f8e3e85cc244faa/build/tests/conftest.py", line 85 in wrapper
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/python.py", line 183 in pytest_pyfunc_call
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/callers.py", line 187 in _multicall
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 84 in <lambda>
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 93 in _hookexec
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/hooks.py", line 286 in __call__
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/python.py", line 1641 in runtest
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/runner.py", line 162 in pytest_runtest_call
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/callers.py", line 187 in _multicall
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 84 in <lambda>
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 93 in _hookexec
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/hooks.py", line 286 in __call__
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/runner.py", line 255 in <lambda>
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/runner.py", line 311 in from_call
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/runner.py", line 254 in call_runtest_hook
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/runner.py", line 215 in call_and_report
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/runner.py", line 126 in runtestprotocol
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/runner.py", line 109 in pytest_runtest_protocol
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/callers.py", line 187 in _multicall
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 84 in <lambda>
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 93 in _hookexec
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/hooks.py", line 286 in __call__
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/main.py", line 348 in pytest_runtestloop
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/callers.py", line 187 in _multicall
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 84 in <lambda>
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 93 in _hookexec
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/hooks.py", line 286 in __call__
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/main.py", line 323 in _main
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/main.py", line 269 in wrap_session
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/main.py", line 316 in pytest_cmdline_main
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/callers.py", line 187 in _multicall
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 84 in <lambda>
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/manager.py", line 93 in _hookexec
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pluggy/hooks.py", line 286 in __call__
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/config/__init__.py", line 162 in main
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/_pytest/config/__init__.py", line 185 in console_main
16:17:26     File "/opt/homebrew/lib/python3.9/site-packages/pytest/__main__.py", line 5 in <module>
16:17:26     File "/opt/homebrew/Cellar/[email protected]/3.9.5/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 87 in _run_code
16:17:26     File "/opt/homebrew/Cellar/[email protected]/3.9.5/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 197 in _run_module_as_main
16:17:26   /Users/ci/TeamCity/temp/agentTmp/custom_script9641381054698255468: line 2: 36023 Abort trap: 6           python3 -m pytest ./tests/
16:17:26   Process exited with code 134
16:17:26 

@merlinND (Member) commented:

> On the amd64 / GCC 9 build, the test suite runs to completion but then segfaults during interpreter shutdown when TensorFlow was used.

Before TensorFlow was installed, it also segfaulted after running test_wrap.py: #301 (comment)
I couldn't reproduce it locally, with or without TF installed :(

@wjakob (Member) commented Jan 28, 2025

What platform (OS, Compiler) did you use for testing? Did you try running on one other OS besides that? (That's usually a good way of flushing out those kinds of issues)

@merlinND (Member) commented:

Before TensorFlow was installed on more CI targets, the GCC9 CI build was already segfaulting.
I tried reproducing the crash locally on amd64, Ubuntu 24.04, GCC 9 but all tests were passing. Likewise for Ubuntu 24.04 + Clang 19, all green on my machine.

Regarding the macOS and Windows failures, I don't think these platforms had been tested yet (although maybe @jhoydis tested on macOS already?).

@jhoydis (Contributor Author) commented Jan 28, 2025

I had tested it for the initial PR on macOS but will do it again now.

@jhoydis (Contributor Author) commented Jan 28, 2025

I have just run all tests successfully on my Mac:

  • Apple M1 / Sequoia 15.1.1
  • AppleClang 16.0.0.16000026
  • TF 2.18.0
  • Python 3.12.3

@wjakob (Member) commented Jan 29, 2025

I just tried running this on a Linux machine and ran into a crash involving an exception raised by Tensorflow (within pybind11 code). I believe what happens here is that libstdc++ traverses the call stack containing Clang-compiled functions using libc++ exceptions. There are some ABI incompatibilities here that have tripped us up in the past. I got things to pass when compiling with GCC, which is curiously the opposite of what the CI complained about.

I believe that we should skip the TF tests on Linux (only Linux; that issue isn't present on macOS) if drjit.config.CXX_COMPILER contains the substring "Clang".

I was able to run things successfully on macOS (TF ver 2.16.2) on my laptop. In contrast, the CI box was running a much older TF version (2.9.2), since that was the newest one available on Python 3.9. I suspect that this is the cause of the crashes we've been seeing. I will look into upgrading Python. But that raises a related question: what's the minimum TF version required for the AD interop to work?

While testing on Linux, I noticed two issues with memory leaks. Both test24 and test30 result in a number of instance leaks (easy to verify by directly running pytest tests/test_wrap_ad.py -k test..). Do you also see this on your end?

@wjakob (Member) commented Jan 29, 2025

Update: the reason why I can't install a more recent version on the macOS CI runner is that the version of macOS is too old. This will take some time to fix (likely not until my return to EPFL). So I've uninstalled TF from this machine.

@wjakob (Member) commented Jan 29, 2025

What the hell. Apparently tensorflow does not support the use of GPUs on Windows since v2.10 (only through WSL2, i.e., the linux version). I believe that the crashes here occur from trying to run CUDA tests together with a Tensorflow build that cannot deal with GPU tensors.

@wjakob (Member) commented Jan 31, 2025

@jhoydis, following the above observations, could you look into the following?

  • Address memory leaks caused by test24 and test30
  • Error out when the user attempts to use TensorFlow with CUDA on Windows, and skip the related tests.
  • Skip tests on macOS for now.

@merlinND (Member) commented Feb 3, 2025

> Address memory leaks caused by test24 and test30

I spent quite a few hours investigating the leaks from test24. I could narrow it down to the following conditions:

  • configs include at least one eager-mode and one XLA-mode
  • LLVM backend (probably because in that case, the DLPack-based tensor retains a reference to the DrJit array's storage)
  • There must be at least one call to dr.forward_to() (dr.backward_from() doesn't trigger the leak).
  • dr.llvm.ad.Float is enough, no need for Array3f to trigger the bug

The variables that get leaked are the function inputs of the XLA config of the test. More specifically, the dr.TensorXf variables that wrap the original DrJit function inputs.

If there's no eager-mode config enabled, then the variables get correctly cleaned up at the very end once the tensorflow::DeviceCompilationCache gets deleted, presumably upon module unloading:

Stack trace
#0  jitc_var_free (index=1, v=0x62a47c0) at /drjit/ext/drjit-core/src/var.cpp:279
#1  0x00007ffff4452187 in jitc_var_dec_ref (index=1, v=0x62a47c0) at /drjit/ext/drjit-core/src/var.cpp:475
#2  0x00007ffff44521bb in jitc_var_dec_ref (index=1) at /drjit/ext/drjit-core/src/var.cpp:482
#3  0x00007ffff45d3b7d in jit_var_dec_ref_impl (index=1) at /drjit/ext/drjit-core/src/api.cpp:589
#4  0x00007ffff479359a in jit_var_dec_ref (index=1) at /drjit/ext/drjit-core/include/drjit-core/jit.h:1103
#5  ad_var_dec_ref_impl (index=1) at /drjit/src/extra/autodiff.cpp:756
#6  0x00007ffff525fc09 in ad_var_dec_ref (index=1) at /drjit/include/drjit/extra.h:509
#7  drjit::DiffArray<(JitBackend)2, float>::~DiffArray (this=0x7fff1e00a808, __in_chrg=<optimized out>) at /drjit/include/drjit/autodiff.h:120
#8  0x00007ffff52e506a in drjit::Tensor<drjit::DiffArray<(JitBackend)2, float> >::~Tensor (this=0x7fff1e00a808, __in_chrg=<optimized out>) at /drjit/include/drjit/tensor.h:111
#9  0x00007ffff5450e1a in nanobind::detail::wrap_destruct<drjit::Tensor<drjit::DiffArray<(JitBackend)2, float> > > (value=0x7fff1e00a808) at /drjit/ext/nanobind/include/nanobind/nb_class.h:245
#10 0x00007ffff5e77c3d in nanobind::detail::inst_dealloc (self=0x7fff1e00a7f0) at /drjit/ext/nanobind/src/nb_type.cpp:255
#11 0x00007ffff5e780be in Py_DECREF (op=0x7fff1e00a7f0) at /usr/include/python3.12/object.h:705
#12 nanobind::detail::inst_dealloc (self=0x7fff1e2572f0) at /drjit/ext/nanobind/src/nb_type.cpp:336
#13 0x00007ffff5e8b1b9 in Py_DECREF (op=0x7fff1e2572f0) at /usr/include/python3.12/object.h:705
#14 Py_XDECREF (op=0x7fff1e2572f0) at /usr/include/python3.12/object.h:798
#15 0x00007ffff5e8df99 in nanobind::detail::ndarray_dec_ref (th=0x7fff1e1f3e70) at /drjit/ext/nanobind/src/nb_ndarray.cpp:623
#16 0x00007ffff5e8e0b1 in operator() (__closure=0x0, mt=0x7fff1e00a7b0) at /drjit/ext/nanobind/src/nb_ndarray.cpp:667
#17 0x00007ffff5e8e0f5 in _FUN () at /drjit/ext/nanobind/src/nb_ndarray.cpp:668
#18 0x00007fffae7cb896 in TF_ManagedBuffer::~TF_ManagedBuffer() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_framework.so.2
#19 0x00007fffaed6651c in tensorflow::Tensor::~Tensor() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_framework.so.2
#20 0x00007fff9633535a in absl::lts_20230802::inlined_vector_internal::Storage<std::variant<tensorflow::Tensor, std::pair<tensorflow::DataType, absl::lts_20230802::InlinedVector<long, 4ul, std::allocator<long> > > >, 8ul, std::allocator<std::variant<tensorflow::Tensor, std::pair<tensorflow::DataType, absl::lts_20230802::InlinedVector<long, 4ul, std::allocator<long> > > > > >::DestroyContents() ()
   from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2
#21 0x00007fff9652b803 in absl::lts_20230802::container_internal::raw_hash_set<absl::lts_20230802::container_internal::FlatHashMapPolicy<tensorflow::DeviceCompilationClusterSignature, std::unique_ptr<tensorflow::DeviceCompilationCache<xla::LocalExecutable>::Entry, std::default_delete<tensorflow::DeviceCompilationCache<xla::LocalExecutable>::Entry> > >, tensorflow::DeviceCompilationClusterSignature::Hash, std::equal_to<tensorflow::DeviceCompilationClusterSignature>, std::allocator<std::pair<tensorflow::DeviceCompilationClusterSignature const, std::unique_ptr<tensorflow::DeviceCompilationCache<xla::LocalExecutable>::Entry, std::default_delete<tensorflow::DeviceCompilationCache<xla::LocalExecutable>::Entry> > > > >::~raw_hash_set() ()
   from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2
#22 0x00007fff9652b68f in tensorflow::DeviceCompiler<xla::LocalExecutable, xla::LocalClient>::~DeviceCompiler() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2
#23 0x00007fff9652b73e in tensorflow::DeviceCompiler<xla::LocalExecutable, xla::LocalClient>::~DeviceCompiler() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2
#24 0x00007fffaeb96f3b in std::__detail::__variant::__gen_vtable_impl<true, std::__detail::__variant::_Multi_array<std::__detail::__variant::__variant_cookie (*)(std::__detail::__variant::_Variant_storage<false, tsl::core::RefCountPtr<tensorflow::ResourceBase>, tsl::core::WeakPtr<tensorflow::ResourceBase> >::_M_reset_impl()::{lambda(auto:1&&)#1}&&, std::variant<tsl::core::RefCountPtr<tensorflow::ResourceBase>, tsl::core::WeakPtr<tensorflow::ResourceBase> >&)>, std::tuple<std::variant<tsl::core::RefCountPtr<tensorflow::ResourceBase>, tsl::core::WeakPtr<tensorflow::ResourceBase> >&>, std::integer_sequence<unsigned long, 0ul> >::__visit_invoke(std::__detail::__variant::_Variant_storage<false, tsl::core::RefCountPtr<tensorflow::ResourceBase>, tsl::core::WeakPtr<tensorflow::ResourceBase> >::_M_reset_impl()::{lambda(auto:1&&)#1}&&, std::variant<tsl::core::RefCountPtr<tensorflow::ResourceBase>, tsl::core::WeakPtr<tensorflow::ResourceBase> >&) () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_framework.so.2
#25 0x00007fffaeb9295a in tensorflow::ResourceMgr::ResourceAndName::~ResourceAndName() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_framework.so.2
#26 0x00007fffaeb92dbc in tensorflow::ResourceMgr::Clear() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_framework.so.2
#27 0x00007fffaeaa4d30 in tensorflow::DynamicDeviceMgr::~DynamicDeviceMgr() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_framework.so.2
#28 0x00007fffaeaa4ece in tensorflow::DynamicDeviceMgr::~DynamicDeviceMgr() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_framework.so.2
#29 0x00007fff96d26da3 in tensorflow::EagerContext::~EagerContext() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2
#30 0x00007fff96d26fa5 in non-virtual thunk to tensorflow::EagerContext::~EagerContext() () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2
#31 0x00007fffd4421c5a in TFE_DeleteContextCapsule(_object*) () from /drjit/build-Debug/.venv/lib/python3.12/site-packages/tensorflow/python/platform/../_pywrap_tensorflow_internal.so
#32 0x000000000054c5c0 in ?? ()
#33 0x000000000057627e in ?? ()
#34 0x0000000000575fcc in ?? ()
#35 0x000000000059f2d5 in ?? ()
#36 0x0000000000573736 in ?? ()
#37 0x0000000000583664 in _PyModule_Clear ()
#38 0x00000000006b1a19 in ?? ()
#39 0x00000000006b0e9d in Py_FinalizeEx ()
#40 0x00000000006bc851 in Py_RunMain ()
#41 0x00000000006bc46d in Py_BytesMain ()
#42 0x00007ffff7c2a1ca in __libc_start_call_main (main=main@entry=0x518850, argc=argc@entry=7, argv=argv@entry=0x7fffffffd468) at ../sysdeps/nptl/libc_start_call_main.h:58
#43 0x00007ffff7c2a28b in __libc_start_main_impl (main=0x518850, argc=7, argv=0x7fffffffd468, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffd458) at ../csu/libc-start.c:360
#44 0x0000000000657c15 in _start ()

When both an Eager and an XLA config run in the same process, tensorflow::DeviceCompiler<xla::LocalExecutable, xla::LocalClient>::~DeviceCompiler is, for some reason, not invoked during cleanup.

I've tried manually clearing various caches I could access from the XLA-compiled function, but no luck:

        test_fn.func._function_cache.clear()
        test_fn.func._function_captures.clear()
        test_fn.func._python_function = None

TensorFlow users have reported XLA leaking memory before: tensorflow/tensorflow#80753, tensorflow/tensorflow#73457

@merlinND
Member

merlinND commented Feb 4, 2025

Error out when the user attempts to use TensorFlow with CUDA on Windows, skip related tests

I've pushed a change to skip TF tests in cases where DrJit supports CUDA, but TF doesn't detect a GPU device. It should be slightly more general than detecting Windows + TF.
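The skip condition described here can be sketched as a small predicate. This is a sketch only: the helper name `should_skip_tf_cuda` and the hard-coded inputs are assumptions for illustration, not the actual `conftest.py` code.

```python
def should_skip_tf_cuda(drjit_has_cuda: bool, tf_gpu_count: int) -> bool:
    """Skip TF interop tests when Dr.Jit can target CUDA but TensorFlow
    does not detect any GPU device (e.g. TF on native Windows)."""
    return drjit_has_cuda and tf_gpu_count == 0

# In the real test suite, tf_gpu_count would come from something like
# len(tf.config.list_physical_devices('GPU')); the boolean inputs here
# only stand in for the three relevant configurations.
```

In `conftest.py`, such a predicate would feed a skip with a custom message along the lines of "Dr.Jit supports CUDA, but TensorFlow found no GPU device", which is more general than special-casing Windows.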

@merlinND
Member

merlinND commented Feb 6, 2025

I've rebased, and added:

  • Skip TF tests on macOS because the CI is not ready yet for it
  • Add gc.collect() during teardown of test_wrap.py to avoid problems during shutdown
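The `gc.collect()` teardown matters because reference cycles holding TensorFlow resources are otherwise only freed at interpreter shutdown, which is exactly where TF's destructors misbehave. A minimal illustration of why an explicit pass helps (the `Node` class and `collect_cycles` helper are invented for this sketch, not part of the test suite):

```python
import gc

class Node:
    """Stand-in for an object caught in a reference cycle, as TF
    callables capturing their own wrappers can be."""
    def __init__(self):
        self.ref = self  # self-referential cycle

def collect_cycles():
    n = Node()
    del n                # object now reachable only through its own cycle
    return gc.collect()  # explicit pass frees it; returns objects found
```

In the teardown of `test_wrap.py`, a bare `gc.collect()` after each test plays the same role: TF-side objects are released while the interpreter is still healthy, instead of during `Py_FinalizeEx`.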

@merlinND merlinND force-pushed the interop-tf branch 4 times, most recently from 1af146f to 7298a9f Compare March 10, 2025 09:41
jhoydis and others added 3 commits March 14, 2025 11:00

  • Support both forward and backward propagation (with some limitations), plus unit tests.
  • Skip TF tests in the case where Dr.Jit supports CUDA but TensorFlow doesn't, e.g. on native Windows.
  • Explicitly call the garbage collector to avoid some potential TF-side leaks.
@merlinND merlinND merged commit 079f92d into mitsuba-renderer:master Mar 14, 2025
5 checks passed
@njroussel
Member

Has anyone been using this? It seems really unreliable. Out of 50 runs of test_wrap.py, it failed 23 times on my end. It's usually one of test23_nested_arrays_bwd or test05_simple_multiarg_bwd (with a tensorflow config, of course).

@wjakob
Member

wjakob commented Apr 3, 2025

@njroussel Is your tensorflow version up-to-date?

@njroussel
Member

Yes, it is. I initially noticed this in a CI run (link).

@merlinND
Member

merlinND commented Apr 3, 2025

I will try to debug the flaky tests on my machine.

Edit: I can reproduce the test failures locally. Interestingly, tests do not fail when running with CUDA_LAUNCH_BLOCKING=1, so I am looking for some kind of synchronization issue.

@njroussel
Member

@merlinND I think you're on the right track. I only briefly tried to debug it yesterday, but noticed that some carefully placed prints made the issue impossible to reproduce. We must have some sort of race condition right now.
