diff --git a/drjit/interop.py b/drjit/interop.py index bd8c5c54c..85b67a901 100644 --- a/drjit/interop.py +++ b/drjit/interop.py @@ -8,6 +8,7 @@ def pytorch_check(value, /): '''Returns ``True`` if ``value`` is a PyTorch tensor''' return type(value).__module__ == 'torch' and type(value).__name__ == 'Tensor' + def pytorch_fp_check(value, /): '''Returns ``True`` if ``value`` is a PyTorch floating point tensor''' return type(value).__module__ == 'torch' and type(value).__name__ == 'Tensor' and value.dtype.is_floating_point @@ -17,12 +18,29 @@ def jax_check(value, /): '''Returns ``True`` if ``value`` is a JAX tensor''' return type(value).__module__.startswith('jaxlib') + +def tf_check(value, /): + '''Returns ``True`` if ``value`` is a TensorFlow tensor''' + return type(value).__module__.startswith('tensorflow') + + +def tf_var_check(value, /): + '''Returns ``True`` if ``value`` is a TensorFlow variable''' + return type(value).__name__ == 'ResourceVariable' + + +def tf_fp_check(value): + '''Returns True if value is a TensorFlow floating point tensor''' + return tf_check(value) and value.dtype.is_floating + + def pytree_check(value, /): '''Returns ``True`` if ``value`` is a structural element of a PyTree''' tp = type(value) return tp is list or tp is tuple or \ tp is dict or getattr(tp, 'DRJIT_STRUCT', None) is not None + def apply(fn, a, /): '''Helper function to recursively map a PyTree through the function ``fn``''' tp = type(a) @@ -79,6 +97,7 @@ def apply2(fn, a, b, /): else: return a + def from_drjit(value, target, enable_grad = False, /): ''' Convert a PyTree containing Dr.Jit arrays/tensors to another array @@ -122,15 +141,28 @@ def fn(h, /): nonlocal tp_index tp = value_tp[tp_index] if value_tp is not None else None tp_index += 1 - if (source == 'torch' and pytorch_check(h)) or \ - (source == 'jax' and jax_check(h)): - r = dr.detail.import_tensor(h, True) + (source == 'jax' and jax_check(h)) or \ + (source == 'tf' and tf_check(h)): + + if source == 'tf' and 
tf_var_check(h) : + r = dr.detail.import_tensor(h.value(), True) + else: + r = dr.detail.import_tensor(h, True) if type(r) is not tp and dr.is_array_v(tp): r = tp(r) if source == 'torch' and enable_grad: if h.requires_grad: dr.enable_grad(r) + if source == 'tf' and enable_grad: + # There is no TF equivalent to h.requires_grad. + # We hence enable gradients for all trainable variables + # and tensors of floating dtype. + if tf_var_check(h): + if h.trainable: + dr.enable_grad(r) + elif h.dtype.is_floating: + dr.enable_grad(r) return r return ... @@ -151,6 +183,20 @@ def fn(h, /): return result +def tf_filter_fp(value, /): + '''Extract a flat list of floating point TensorFlow tensors from the PyTree ``value``''' + + result = [] + + def fn(h, /): + if tf_fp_check(h): + result.append(h) + return ... + + apply(fn, value) + return result + + def pytorch_grad(value, /): '''Extract a the gradients of PyTorch tensors from the PyTree ``value``''' @@ -195,7 +241,7 @@ def fn(a, b): return apply2(fn, a, b) -def fixup_grad(a, b, target, /): +def fixup_grad(a, b, target, /, tf_add_zero = False): ''' Fix up gradients so that they are accepted by routines like ``jax.vjp``, ``torch.autograd.backward``, etc. @@ -208,6 +254,9 @@ def fixup_grad(a, b, target, /): - replaces gradients for non-differentiable arrays with special objects that JAX expects. ''' + if (target == 'tf') and not tf_add_zero: + # Nothing to do, let's save the trouble + return a def fn(a, b): # Ignore structural PyTree elements @@ -216,6 +265,7 @@ def fn(a, b): is_jax = target == 'jax' and jax_check(a) is_torch = target == 'torch' and pytorch_check(a) + is_tf_fp = target == 'tf' and tf_add_zero and tf_fp_check(a) # JAX really doesn't like receiving gradients/tangents for non-diff. # elements. It wants a special array with dtype `jax.float0`. 
Such @@ -227,6 +277,11 @@ def fn(a, b): import numpy return numpy.zeros(getattr(b, 'shape', ()), dtype=jax.float0) + if is_tf_fp and tf_add_zero: + import tensorflow as tf + with tf.device(a.device): + return tf.zeros_like(a) + a + if type(a) is type(b): if is_jax or (is_torch and a.dtype.is_floating_point): return a.reshape(b.shape) @@ -292,10 +347,51 @@ def unflatten(desc, *flat): list(reversed(desc))) +def wrap_into_dr_tensor(value): + '''Helper to transform a PyTree's members to Dr.Jit tensors''' + def fn(h): + tp = type(h) + if dr.is_array_v(tp): + if not dr.is_tensor_v(h): + h = dr.tensor_t(tp)(h) + return h + return ... + return apply(fn, value) + + +def wrap_into_tf_tensor(value): + '''Helper to transform a PyTree's members to TF tensors''' + import tensorflow as tf + def fn(h): + try: + return tf.convert_to_tensor(h) + except ValueError: + return tf.constant(-1, tf.float32) + return tf.nest.map_structure(fn, value) + + +def find_first_tf_tensor(value): + '''Finds the first TensorFlow tensor in a PyTree''' + if isinstance(value, (list, tuple)): + for item in value: + result = find_first_tf_tensor(item) + if tf_check(result): + return result + elif isinstance(value, dict): + for key in value: + result = find_first_tf_tensor(value[key]) + if tf_check(result): + return result + else: + if tf_check(value): + return value + return None + + class WrapADOp(dr.CustomOp): ''' Dr.Jit custom operation that wraps differentiable computation performed - using another AD framework (e.g., PyTorch) + using another AD framework (e.g., PyTorch, TensorFlow) ''' def eval(self, func, target, *args, **kwargs): # Convert input PyTrees from Dr.Jit @@ -305,7 +401,17 @@ def eval(self, func, target, *args, **kwargs): self.func = func # Evaluate the function using another array programming framework - self.out = func(*self.args, **self.kwargs) + if target == 'tf': + import tensorflow as tf + self.watched_vars = wrap_into_tf_tensor([self.args, self.kwargs]) + self.device = 
find_first_tf_tensor([self.args, self.kwargs]).device + with tf.device(self.device): # Ensure that TF runs on the correct device + with tf.GradientTape() as tape: + tape.watch(self.watched_vars) + self.out = func(*self.args, **self.kwargs) + self.tape = tape + else: + self.out = func(*self.args, **self.kwargs) # Convert the out PyTree to Dr.Jit return to_drjit(self.out, target) @@ -317,7 +423,6 @@ def forward(self): grad_kwargs, _ = from_drjit(self.grad_in('kwargs'), target) grad_args = fixup_grad(grad_args, self.args, target) grad_kwargs = fixup_grad(grad_kwargs, self.kwargs, target) - if target == 'torch': import torch.autograd.forward_ad as fa @@ -335,22 +440,31 @@ def wrapper(args, kwargs): _, grad_out = jax.jvp( wrapper, (self.args, self.kwargs), (grad_args, grad_kwargs) ) + elif target == 'tf': + import tensorflow as tf + primals = list(self.args) + list(self.kwargs.values()) + tangents = list(grad_args) + list(grad_kwargs.values()) + with tf.device(self.device): + with tf.autodiff.ForwardAccumulator( + primals=tf_filter_fp(primals), + tangents=tf_filter_fp(tangents) + ) as acc: + out = self.func(*self.args, **self.kwargs) + grad_out = acc.jvp(wrap_into_tf_tensor(out)) else: raise RuntimeError('WrapADOp.forward(): unsupported framework!') - self.set_grad_out(to_drjit(grad_out, target)) def backward(self): target = self.target - grad_out, _ = from_drjit(self.grad_out(), target) - grad_out = fixup_grad(grad_out, self.out, target) + # Sever link to DrJit (via DLPack `owner` field) to avoid leaks (TF only). 
+ grad_out = fixup_grad(grad_out, self.out, target, tf_add_zero=True) if target == 'torch': import torch torch.autograd.backward(pytorch_filter_fp(self.out), pytorch_filter_fp(grad_out)) - grad_args = pytorch_grad(self.args) grad_kwargs = pytorch_grad(self.kwargs) elif target == 'jax': @@ -359,11 +473,17 @@ def backward(self): def wrapper(args, kwargs): return self.func(*args, **kwargs) - primals, vjp_fun = jax.vjp(wrapper, self.args, self.kwargs) + _, vjp_fun = jax.vjp(wrapper, self.args, self.kwargs) grad_args, grad_kwargs = vjp_fun(grad_out) + elif target == 'tf': + import tensorflow as tf + with tf.device(self.device): # Ensure that TF runs on the correct device + out = wrap_into_tf_tensor(self.out) + + grad_args, grad_kwargs = self.tape.gradient(out, self.watched_vars, + output_gradients=grad_out) else: raise RuntimeError('WrapADOp.backward(): unsupported framework!') - self.set_grad_in('args', to_drjit(grad_args, target, self.args_tp)) self.set_grad_in('kwargs', to_drjit(grad_kwargs, target, self.kwargs_tp)) @@ -392,24 +512,13 @@ def forward(ctx, func, desc, *inputs): inputs = to_drjit(inputs, 'torch', enable_grad=True) args, kwargs = unflatten(desc, *inputs) - def wrap_into_tensor(value): - '''Helper to transform a PyTree's members to tensors''' - def fn(h): - tp = type(h) - if dr.is_array_v(tp): - if not dr.is_tensor_v(h): - h = dr.tensor_t(tp)(h) - return h - return ... 
- return apply(fn, value) - # Run the function, flatten the output PyTree and convert its members to tensors global torch_desc_o with torch_set_grad_enabled(True): # Torch autograd tracing is disabled within `Function.forward` # we turn it back on here in case func itself uses torch torch_desc_o, *output = flatten(func(*args, **kwargs)) - output = wrap_into_tensor(output) + output = wrap_into_dr_tensor(output) # Stash inputs and outputs for later use ctx.inputs, ctx.output = inputs, output @@ -478,8 +587,9 @@ def wrap(source: typing.Union[str, types.ModuleType], This function wraps computation performed using one array programming framework to expose it in another. Currently, `PyTorch - `__ and `JAX `__ are - supported, though other frameworks may be added in the future. + `__, `TensorFlow `__, + and `JAX `__ are supported, though other + frameworks may be added in the future. Annotating a function with :py:func:`@drjit.wrap ` adds code that suitably converts arguments and return values. Furthermore, it @@ -533,6 +643,44 @@ def wrap(source: typing.Union[str, types.ModuleType], An `issue `__ was filed on the PyTorch bugtracker. + + * - ``drjit`` → ``tf`` + - .. centered:: ✅ + - .. centered:: ✅ + - .. centered:: ✅ + - You may want to further annotate the wrapped function with + ``tf.function`` to trace and just-in-time compile it in the + TensorFlow environment, i.e., + + .. code-block:: python + + @dr.wrap(source='drjit', target='tf') + @tf.function(jit_compile=False) # Set to True for XLA mode + + **Limitation**: There is an issue for tf.int32 tensors which are + wrongly placed on CPU by DLPack. This can lead to inconsistent device + placement of tensors. + + An `issue `__ + was filed on the TensorFlow bugtracker. + + * - ``tf`` → ``drjit`` + - .. centered:: ✅ + - .. centered:: ❌ + - .. centered:: ✅ + - TensorFlow has some limitations with respect to custom gradients + in forward-mode AD. 
+ + **Limitation**: TensorFlow does not allow for non-tensor + input structures in functions with + `custom gradients + `__. + + TensorFlow has a bug for functions with custom gradients and + keyword arguments. + + An `issue `__ + was filed on the TensorFlow bugtracker. + * - ``drjit`` → ``jax`` - .. centered:: ✅ - .. centered:: ✅ @@ -574,6 +722,11 @@ def wrap(source: typing.Union[str, types.ModuleType], ``uint32``, or ``uint64``-typed arrays). Use signed integer types to work around this issue. + - TensorFlow has limitations with respect to forward-mode AD for + functions with custom gradients. There is also an `issue for functions + with keyword arguments + `__. + - Dr.Jit currently lacks support for most 8- and 16-bit numeric types (besides half precision floats). @@ -584,12 +737,12 @@ def wrap(source: typing.Union[str, types.ModuleType], Args: source (str | module): The framework used *outside* of the wrapped function. The argument is currently limited to either ``'drjit'``, - ``'torch'``, or ``jax'``. For convenience, the associated Python + ``'torch'``, ``'tf'``, or ``'jax'``. For convenience, the associated Python module can be specified as well. target (str | module): The framework used *inside* of the wrapped function. The argument is currently limited to either ``'drjit'``, - ``'torch'``, or ``'jax'``. For convenience, the associated Python + ``'torch'``, ``'tf'``, or ``'jax'``. For convenience, the associated Python module can be specified as well. 
Returns: @@ -599,17 +752,17 @@ def wrap(source: typing.Union[str, types.ModuleType], # Get module names if source and target are not already strings source = source.__name__ if not isinstance(source, str) else source target = target.__name__ if not isinstance(target, str) else target - valid_types = ('drjit', 'torch', 'jax') + valid_types = ('drjit', 'torch', 'jax', 'tf') if source not in valid_types: - raise Exception("drjit.wrap(): unknown 'source' argument.") + raise ValueError("drjit.wrap(): unknown 'source' argument.") if target not in valid_types: - raise Exception("drjit.wrap(): unknown 'target' argument.") + raise ValueError("drjit.wrap(): unknown 'target' argument.") if source != 'drjit' and target != 'drjit': - raise Exception("drjit.wrap(): at least one of 'source' and " - "'target' must equal \"drjit\".") + raise ValueError("drjit.wrap(): at least one of 'source' and " + "'target' must equal \"drjit\".") if source == target: # Nothing to do @@ -637,7 +790,36 @@ def wrapper_2(*args, **kwargs): return wrapper_2 return wrapper - else: - raise Exception("drjit.wrap(): unsupported combination of 'source' and 'target'.") - return None + elif target == 'drjit' and source == 'tf': + + import tensorflow as tf + + def wrapper(func): + + @tf.custom_gradient + def wrapper_2(*args, **kwargs): + inputs = to_drjit((args, kwargs), 'tf', enable_grad=True) + outputs = func(*inputs[0], **inputs[1]) + def grad(*dy): + if kwargs: + raise TypeError("Keyword arguments are not allowed for 'tf->drjit'") + grad_outputs = to_drjit(dy, 'tf') + out = flatten(outputs)[1:] + out = wrap_into_dr_tensor(out) + dr.set_grad(out, grad_outputs) + vars = flatten(inputs[0])[1:] # Only gradients for args are computed due + # to a TF bug https://github.com/tensorflow/tensorflow/issues/77559 + grads = dr.backward_to(vars) + grads = from_drjit(grads, 'tf')[0] + # Set gradients for non-differentiable tensors to None + grads = [(g if dr.grad_enabled(vars[i]) else None) \ + for i, g 
in enumerate(flatten(grads)[1:])] + return grads + return from_drjit(outputs, 'tf')[0], grad + + return wrapper_2 + + return wrapper + else: + raise ValueError("drjit.wrap(): unsupported combination of 'source' and 'target'.") diff --git a/tests/conftest.py b/tests/conftest.py index ab438ff6f..2c220d3f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,7 +76,7 @@ def wrapped(func): return wrapped -def skip_on(exception, reason): +def skip_on(exception, reason, msg=None): from functools import wraps def wrapped(func): @wraps(func) @@ -87,9 +87,9 @@ def wrapper(*args, **kwargs): m = str(e) c = str(e.__cause__) if reason in m: - pytest.skip(m) + pytest.skip(msg if msg is not None else m) elif reason in c: - pytest.skip(c) + pytest.skip(msg if msg is not None else c) else: raise e diff --git a/tests/test_conversion.py b/tests/test_conversion.py index c8f9a40fa..d707f352e 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -1,6 +1,17 @@ import drjit as dr import pytest + +def skip_tf_if_not_available(t): + # Skip overall if TF is not available + pytest.importorskip("tensorflow.config") + # Skip CUDA backend roundtrip if TensorFlow doesn't support + # CUDA, e.g. on native Windows since version 2.11. 
+ from tensorflow.config import list_physical_devices + if (dr.backend_v(t) == dr.JitBackend.CUDA) and not list_physical_devices("GPU"): + pytest.skip("TensorFlow didn't detect a CUDA device, skipping.") + + # Test conversions to/from numpy (tensors & dynamic arrays) @pytest.test_arrays('is_tensor, -bool, -float16') def test01_roundtrip_dynamic_numpy(t): @@ -52,12 +63,7 @@ def test04_roundtrip_vector_torch(t): # Test conversions to/from tf (tensors & dynamic array) @pytest.test_arrays('tensor, -bool, -float16') def test05_roundtrip_dynamic_tf(t): - pytest.importorskip("tensorflow.config") - - import sys - if sys.platform == 'win32' and dr.backend_v(t) == dr.JitBackend.CUDA: - pytest.skip('Skipping TensorFlow GPU test on Windows') - + skip_tf_if_not_available(t) a = t([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) roundtrip = t(a.tf()) @@ -71,12 +77,7 @@ def test05_roundtrip_dynamic_tf(t): # Test conversions to/from tf (vectors) @pytest.test_arrays('vector, shape=(3, *), -bool, -float16') def test06_roundtrip_vector_tf(t): - pytest.importorskip("tensorflow.config") - - import sys - if sys.platform == 'win32' and dr.backend_v(t) == dr.JitBackend.CUDA: - pytest.skip('Skipping TensorFlow GPU test on Windows') - + skip_tf_if_not_available(t) a = t([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) roundtrip = t(a.tf()) diff --git a/tests/test_wrap.py b/tests/test_wrap.py index 3fd90d8d4..ecfdf3e14 100644 --- a/tests/test_wrap.py +++ b/tests/test_wrap.py @@ -1,9 +1,14 @@ +import platform +import warnings + import drjit as dr import pytest -import warnings + configs_jax = [] configs_torch = [] +configs_tf = [] +configs_tf_jit = [] try: # Ignore deprecation warnings generated by the PyTorch package @@ -24,6 +29,7 @@ if torch.__version__ < torch.torch_version.TorchVersion('2.1.3'): supports_bool = False configs_torch.append(('torch', supports_bool, False)) + del supports_bool except ImportError: pass @@ -38,15 +44,40 @@ configs_jax.append(('jax', supports_bool, False)) 
configs_jax.append(('jax', supports_bool, True)) jit = jax.jit + del supports_bool except ImportError: pass -configs = configs_torch + configs_jax + +# TensorFlow is not setup correctly on the macOS CI yet. +# TODO: re-enable TF tests on macOS +if platform.system() != "Darwin": + try: + import tensorflow as tf + supports_bool = True + configs_tf.append(('tf', supports_bool, 'eager')) + + # Test configurations which are only used + # for the 'drjit'->'tf' direction, where + # the TF function is executed in Graph or XLA mode + configs_tf_jit.append(('tf', supports_bool, 'graph')) + configs_tf_jit.append(('tf', supports_bool, 'xla')) + + TF_HAS_GPU = bool(tf.config.list_physical_devices("GPU")) + del supports_bool + + except ImportError: + pass + +configs = configs_torch + configs_jax + configs_tf + def wrap(config): def wrapper(func): - if config[2]: + if config[0] == 'jax' and config[2]: func = jax.jit(func) + elif config[0] == 'tf' and config[2] != 'eager': + func = tf.function(jit_compile=config[2]=='xla')(func) return dr.wrap(source='drjit', target=config[0])(func) return wrapper @@ -67,11 +98,51 @@ def torch_dtype(t): else: raise Exception("Unsupported variable type") +def tf_dtype(t): + import tensorflow as tf + vt = dr.type_v(t) + if vt == dr.VarType.Float16: + return tf.float16 + elif vt == dr.VarType.Float32: + return tf.float32 + elif vt == dr.VarType.Float64: + return tf.float64 + else: + raise Exception("Unsupported variable type") + + +def skip_if_unsupported(config, t, needs_int32: bool = False): + """Helper function to skip tests in configurations which are known + to be unsupported.""" + + # Skip TF test if DrJit supports CUDA on this platform, + # but TF does not (e.g. native Windows). 
+ if config[0] == "tf": + if dr.backend_v(t) == dr.JitBackend.CUDA: + if not TF_HAS_GPU: + pytest.skip("TensorFlow didn't detect a CUDA device, skipping.") + if needs_int32 and config[2] == "graph": + pytest.xfail("Expected to fail due to TF dlpack issue for int32," + " see https://github.com/tensorflow/tensorflow/issues/78091") + + +@pytest.fixture(scope="module", autouse=True) +def wrap_teardown(): + # Run tests in this module + yield + + # Teardown: invoke garbage collector explicitly to avoid + # issues related to destruction order during shutdown. + import gc + gc.collect() + @pytest.mark.parametrize('is_diff', [True, False]) -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') +@pytest.skip_on(RuntimeError, "backend does not support the requested type of atomic reduction") def test01_simple_bwd(t, config, is_diff): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x): return x * 2 @@ -91,36 +162,59 @@ def test_fn(x): @pytest.mark.parametrize('scalar_deriv', [True, False]) @pytest.mark.parametrize('is_diff', [True, False]) -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "backend does not support the requested type of atomic reduction") def test02_flipped_simple_bwd(t, config, is_diff, scalar_deriv): + skip_if_unsupported(config, t) @wrap_flipped(config) def test_fn(x): assert dr.is_array_v(x) return x * 2 - import torch - dt = torch_dtype(t) - x = torch.arange(3, dtype=dt) - x.requires_grad = is_diff - - y = test_fn(x) - assert torch.all(y == torch.arange(3, dtype=dt) * 2) - - if is_diff: - if scalar_deriv: - y.sum().backward() - assert torch.all(x.grad == torch.tensor([2, 2, 2], dtype=dt)) + if config[0] == 'torch': + import torch + dt = torch_dtype(t) + x = torch.arange(3, dtype=dt) + 
x.requires_grad = is_diff + + y = test_fn(x) + assert torch.all(y == torch.arange(3, dtype=dt) * 2) + + if is_diff: + if scalar_deriv: + y.sum().backward() + assert torch.all(x.grad == torch.tensor([2, 2, 2], dtype=dt)) + else: + torch.autograd.backward(y, torch.tensor([10, 20, 30], dtype=dt)) + assert torch.all(x.grad == torch.tensor([20, 40, 60], dtype=dt)) else: - torch.autograd.backward(y, torch.tensor([10, 20, 30], dtype=dt)) - assert torch.all(x.grad == torch.tensor([20, 40, 60], dtype=dt)) - else: - assert y.grad is None - -@pytest.mark.parametrize('config', configs) + assert y.grad is None + + elif config[0] == 'tf': + import tensorflow as tf + dt = tf_dtype(t) + x = tf.cast(tf.range(3), dt) + y = test_fn(x) + assert tf.reduce_all(y == x * 2) + if is_diff: + if scalar_deriv: + with tf.GradientTape() as tape: + tape.watch(x) + out = test_fn(x) + grad = tape.gradient(out, x) + assert tf.reduce_all(grad == tf.constant([2, 2, 2], dtype=dt)) + else: + with tf.GradientTape() as tape: + tape.watch(x) + out = test_fn(x) + grad = tape.gradient(out, x, output_gradients=tf.constant([10, 20, 30], dtype=dt)) + assert tf.reduce_all(grad == tf.constant([20, 40, 60], dtype=dt)) + +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test03_simple_fwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x): return x * 2 @@ -134,10 +228,13 @@ def test_fn(x): assert dr.all(y == [0, 2, 4]) assert dr.all(y.grad == [20, 40, 60]) - -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') def test04_flipped_simple_fwd(t, config): + skip_if_unsupported(config, t) + if config[0] == "tf": + pytest.skip("Skipped due to limited support of `tf.custom_gradient()`" + " for forward-mode AD.") import torch.autograd.forward_ad as fwd_ad @wrap_flipped(config) @@ -163,10 +260,10 @@ def test_fn(x): assert 
torch.all(w == x*2) assert torch.all(wd == xd*2) - -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test05_simple_multiarg_bwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x, y): return x + y, y, x @@ -188,42 +285,63 @@ def test_fn(x, y): assert dr.all(x.grad == [60, 80, 100]) assert dr.all(y.grad == [100]) - -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "backend does not support the requested type of atomic reduction") def test06_flipped_simple_multiarg_bwd(t, config): + skip_if_unsupported(config, t) @wrap_flipped(config) def test_fn(x, y): return x + y, y, x + if config[0] == 'torch': + dt = torch_dtype(t) + x = torch.arange(3, dtype=dt, requires_grad=True) + y = torch.tensor([4], dtype=dt, requires_grad=True) + a, b, c = test_fn(x, y) - dt = torch_dtype(t) - x = torch.arange(3, dtype=dt, requires_grad=True) - y = torch.tensor([4], dtype=dt, requires_grad=True) - a, b, c = test_fn(x, y) - - a.grad = torch.tensor([10, 20, 30], dtype=dt) - b.grad = torch.tensor([40], dtype=dt) - c.grad = torch.tensor([50, 60, 70], dtype=dt) - - assert torch.all(a == torch.tensor([4, 5, 6], dtype=dt)) - assert torch.all(b == torch.tensor([4], dtype=dt)) - assert torch.all(c == torch.tensor([0, 1, 2], dtype=dt)) - - torch.autograd.backward( - (a, b, c), - ( - torch.tensor([10, 20, 30], dtype=dt), - torch.tensor([40], dtype=dt), - torch.tensor([50, 60, 70], dtype=dt), - )) - - assert torch.all(x.grad == torch.tensor([60, 80, 100], dtype=dt)) - assert torch.all(y.grad == torch.tensor([100], dtype=dt)) + a.grad = torch.tensor([10, 20, 30], dtype=dt) + b.grad = torch.tensor([40], dtype=dt) + c.grad = torch.tensor([50, 60, 70], dtype=dt) + + assert torch.all(a == torch.tensor([4, 5, 6], dtype=dt)) + assert 
torch.all(b == torch.tensor([4], dtype=dt)) + assert torch.all(c == torch.tensor([0, 1, 2], dtype=dt)) + + torch.autograd.backward( + (a, b, c), + ( + torch.tensor([10, 20, 30], dtype=dt), + torch.tensor([40], dtype=dt), + torch.tensor([50, 60, 70], dtype=dt), + )) + + assert torch.all(x.grad == torch.tensor([60, 80, 100], dtype=dt)) + assert torch.all(y.grad == torch.tensor([100], dtype=dt)) + + elif config[0] == 'tf': + dt = tf_dtype(t) + x = tf.constant([0, 1, 2], dtype=dt) + y = tf.constant([4], dtype=dt) + a, b, c = test_fn(x, y) -@pytest.mark.parametrize('config', configs) + assert tf.reduce_all(a == tf.constant([4, 5, 6], dtype=dt)) + assert tf.reduce_all(b == tf.constant([4], dtype=dt)) + assert tf.reduce_all(c == tf.constant([0, 1, 2], dtype=dt)) + + with tf.GradientTape() as tape: + tape.watch([x, y]) + a, b, c = test_fn(x, y) + grad = tape.gradient([a, b, c], [x, y], + output_gradients=[tf.constant([10, 20, 30], dtype=dt), + tf.constant([40], dtype=dt), + tf.constant([50, 60, 70], dtype=dt)]) + assert tf.reduce_all(grad[0] == tf.constant([60, 80, 100], dtype=dt)) + assert tf.reduce_all(grad[1] == tf.constant([100], dtype=dt)) + +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test07_simple_multiarg_fwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x, y): return x + y, y, x @@ -244,11 +362,14 @@ def test_fn(x, y): assert dr.all(b.grad == [40]) assert dr.all(c.grad == [10, 20, 30]) - -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "not implemented for 'Half'") def test08_filled_simple_multiarg_fwd(t, config): + skip_if_unsupported(config, t) + if config[0] == "tf": + pytest.skip("Skipped due to limited support of `tf.custom_gradient()`" + " for forward-mode AD.") @wrap_flipped(config) def test_fn(x, y): return x + y, y + 1, x + 
1 @@ -272,10 +393,10 @@ def test_fn(x, y): assert dr.all(bd == torch.tensor([40], dtype=dt)) assert dr.all(cd == torch.tensor([10, 20, 30], dtype=dt)) - -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float32,shape=(*)') def test09_nondiff_bwd(t, config): + skip_if_unsupported(config, t, needs_int32=True) @wrap(config) def test_fn(x, y, z): return x, y, z @@ -297,32 +418,55 @@ def test_fn(x, y, z): assert dr.all(x.grad == [10, 20, 30]) -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float32,shape=(*)') def test10_flipped_nondiff_bwd(t, config): - with dr.detail.scoped_rtld_deepbind(): - @wrap_flipped(config) - def test_fn(x, y, z): - return x*2, y+1, ~z - - dt = torch_dtype(t) - x = torch.arange(3, dtype=dt, requires_grad=True) - y = x.type(torch.int32) - if config[1]: - z = y > 0 - else: - z = y - - a, b, c = test_fn(x, y, z) - assert torch.all(a == x*2) and torch.all(b == y + 1) and torch.all(c == ~z) - - torch.autograd.backward(a, torch.tensor([10, 20, 30], dtype=dt)) - assert dr.all(x.grad == torch.tensor([20, 40, 60], dtype=dt)) - - -@pytest.mark.parametrize('config', configs) + skip_if_unsupported(config, t) + if config[0] == 'torch': + with dr.detail.scoped_rtld_deepbind(): + @wrap_flipped(config) + def test_fn(x, y, z): + return x*2, y+1, ~z + + dt = torch_dtype(t) + x = torch.arange(3, dtype=dt, requires_grad=True) + y = x.type(torch.int32) + if config[1]: + z = y > 0 + else: + z = y + + a, b, c = test_fn(x, y, z) + assert torch.all(a == x*2) and torch.all(b == y + 1) and torch.all(c == ~z) + + torch.autograd.backward(a, torch.tensor([10, 20, 30], dtype=dt)) + assert dr.all(x.grad == torch.tensor([20, 40, 60], dtype=dt)) + elif config[0] == 'tf': + with dr.detail.scoped_rtld_deepbind(): + @wrap_flipped(config) + def test_fn(x, y, z): + return x*2, y+1, ~z + + dt = tf_dtype(t) 
+ x = tf.constant([0, 1, 2], dtype=dt) + y = tf.cast(x, dtype=tf.int32) + if config[1]: + z = y > 0 + else: + z = y + + a, b, c = test_fn(x, y, z) + assert tf.reduce_all(a == x*2) and tf.reduce_all(b == y+1) and tf.reduce_all(c == ~z) + with tf.GradientTape() as tape: + tape.watch(x) + a, b, c = test_fn(x, y, z) + grad = tape.gradient(a, x, output_gradients=tf.constant([10, 20, 30], dtype=dt)) + assert tf.reduce_all(grad == tf.constant([20, 40, 60], dtype=dt)) + +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float32,shape=(*)') def test11_nondiff_fwd(t, config): + skip_if_unsupported(config, t, needs_int32=True) @wrap(config) def test_fn(x, y, z): return x, y, z @@ -343,10 +487,15 @@ def test_fn(x, y, z): assert dr.all(a.grad == [10, 20, 30]) -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "not implemented for 'Half'") def test12_flipped_nondiff_fwd(t, config): + skip_if_unsupported(config, t) + if config[0] == "tf": + pytest.skip("Skipped due to limited support of `tf.custom_gradient()`" + " for forward-mode AD.") + @wrap_flipped(config) def test_fn(x, y, z): return x*2, y+1, ~z @@ -373,10 +522,10 @@ def test_fn(x, y, z): assert torch.all(a == x*2) and torch.all(b == y + 1) and torch.all(c == ~z) assert dr.all(ad == torch.tensor([20, 40, 60], dtype=dt)) - -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test13_scalar_bwd(t, config): + skip_if_unsupported(config, t, needs_int32=True) @wrap(config) def test_fn(x, y, z): return x*2, y, z @@ -392,28 +541,44 @@ def test_fn(x, y, z): assert dr.all(x.grad == [20, 40, 60]) - -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) 
@pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "not implemented for 'Half'") def test14_flipped_scalar_bwd(t, config): + skip_if_unsupported(config, t) @wrap_flipped(config) def test_fn(x, y, z): return x*2, y+1, z+1 - dt = torch_dtype(t) - x = torch.arange(3, dtype=dt, requires_grad=True) + if config[0] == 'torch': + dt = torch_dtype(t) + x = torch.arange(3, dtype=dt, requires_grad=True) - a, b, c = test_fn(x, 4, 5.0) - assert torch.all(a == x*2) and (b == 5) and (c == 6) + a, b, c = test_fn(x, 4, 5.0) + assert torch.all(a == x*2) and (b == 5) and (c == 6) - torch.autograd.backward(a, torch.tensor([10, 20, 30], dtype=dt)) + torch.autograd.backward(a, torch.tensor([10, 20, 30], dtype=dt)) + + assert torch.all(x.grad == torch.tensor([20, 40, 60], dtype=dt)) - assert torch.all(x.grad == torch.tensor([20, 40, 60], dtype=dt)) + elif config[0] == 'tf': + dt = tf_dtype(t) + x = tf.constant([0, 1, 2], dtype=dt) + + a, b, c = test_fn(x, 4, 5.0) + assert tf.reduce_all(a == x*2) and (b == 5) and (c == 6) -@pytest.mark.parametrize('config', configs) + with tf.GradientTape() as tape: + tape.watch(x) + a, b, c = test_fn(x, 4, 5.0) + grad = tape.gradient(a, x, output_gradients=tf.constant([10, 20, 30], dtype=dt)) + + assert tf.reduce_all(grad == tf.constant([20, 40, 60], dtype=dt)) + +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test15_scalar_fwd(t, config): + skip_if_unsupported(config, t, needs_int32=True) @wrap(config) def test_fn(x, y, z): return x, y, z @@ -429,10 +594,15 @@ def test_fn(x, y, z): assert dr.all(a.grad == [10, 20, 30]) -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') -@pytest.mark.skip(reason='Skipped until issue https://github.com/pytorch/pytorch/issues/117491 is fixed.') def test14_flipped_scalar_fwd(t, config): + 
skip_if_unsupported(config, t) + if config[0] == 'torch': + pytest.skip('Skipped until issue https://github.com/pytorch/pytorch/issues/117491 is fixed.') + if config[0] == "tf": + pytest.skip("Skipped due to limited support of `tf.custom_gradient()`" + " for forward-mode AD.") @wrap_flipped(config) def test_fn(x, y, z): return x*2, y+1, z+1 @@ -447,12 +617,15 @@ def test_fn(x, y, z): a, b, c = test_fn(x, 4, 5.0) assert torch.all(a == x*2) and torch.all(b == 5) and torch.all(c == 6) a, ad = fwd_ad.unpack_dual(a) - assert torch.all(xd == torch.tensor([20, 40, 60], dtype=dt)) - + assert torch.all(ad == torch.tensor([20, 40, 60], dtype=dt)) -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test15_custom_class_bwd(t, config): + skip_if_unsupported(config, t) + if config[0] == "tf" and config[2] in ['graph', 'xla']: + pytest.skip("Skipped since TF Graph/XLA mode does not support class inputs.") + class MyClass: pass @@ -470,14 +643,18 @@ def test_fn(x, y): a.grad = [10, 20, 30] dr.backward_to(x) - assert dr.all(x.grad == [10, 20, 30]) - -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "not implemented for 'Half'") def test16_flipped_custom_class_bwd(t, config): + skip_if_unsupported(config, t) + if config[0] == "tf": + pytest.skip("Skipped since tf.custom_gradient only supports " + "Tensor inputs (see " + "https://www.tensorflow.org/api_docs/python/tf/custom_gradient).") + class MyClass: pass @@ -497,10 +674,13 @@ def test_fn(x, y): assert torch.all(x.grad == torch.tensor([10, 20, 30], dtype=dt)) - -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def 
test17_custom_class_fwd(t, config): + skip_if_unsupported(config, t) + if config[0] == "tf" and config[2] in ('graph', 'xla'): + pytest.skip("Skipped since TF Graph/XLA mode does not support class inputs.") + class MyClass: pass @@ -520,11 +700,15 @@ def test_fn(x, y): assert dr.all(a.grad == [10, 20, 30]) - -@pytest.mark.skip(reason='Skipped until issue https://github.com/pytorch/pytorch/issues/117491 is fixed.') -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') def test18_flipped_custom_class_fwd(t, config): + skip_if_unsupported(config, t) + if config[0] == 'torch': + pytest.skip('Skipped until issue https://github.com/pytorch/pytorch/issues/117491 is fixed.') + if config[0] == "tf": + pytest.skip("Skipped due to limited support of `tf.custom_gradient()`" + " for forward-mode AD.") class MyClass: pass @@ -548,10 +732,10 @@ def test_fn(x, y): assert torch.all(ad == torch.tensor([10, 20, 30], dtype=dt)) - -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test19_args_kwargs_bwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(*args, **kwargs): return args[0] * kwargs["y"] @@ -567,31 +751,46 @@ def test_fn(*args, **kwargs): assert dr.all(x.grad == [40, 80, 120]) assert dr.all(y.grad == [80]) -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "backend does not support the requested type of atomic reduction") def test20_flipped_args_kwargs_bwd(t, config): + skip_if_unsupported(config, t) @wrap_flipped(config) def test_fn(*args, **kwargs): return args[0] * kwargs["y"] - dt = torch_dtype(t) - x = torch.arange(3, dtype=dt, requires_grad=True) - y = 
torch.tensor([4], dtype=dt, requires_grad=True) - r = test_fn(x, y=y) + if config[0] == 'torch': + dt = torch_dtype(t) + x = torch.arange(3, dtype=dt, requires_grad=True) + y = torch.tensor([4], dtype=dt, requires_grad=True) + r = test_fn(x, y=y) - torch.autograd.backward( - r, - torch.tensor([10, 20, 30], dtype=dt) - ) + torch.autograd.backward( + r, + torch.tensor([10, 20, 30], dtype=dt) + ) - assert torch.all(x.grad == torch.tensor([40, 80, 120], dtype=dt)) - assert torch.all(y.grad == torch.tensor([80], dtype=dt)) + assert torch.all(x.grad == torch.tensor([40, 80, 120], dtype=dt)) + assert torch.all(y.grad == torch.tensor([80], dtype=dt)) + elif config[0] == 'tf': + dt = tf_dtype(t) + x = tf.constant([0,1,2], dtype=dt) + y = tf.constant([4], dtype=dt) -@pytest.mark.parametrize('config', configs) + with tf.GradientTape() as tape: + tape.watch([x, y]) + r = test_fn(x, y=y) + + # See issue https://github.com/tensorflow/tensorflow/issues/77559 + with pytest.raises(TypeError, match="Keyword arguments are not allowed for 'tf->drjit'"): + grad = tape.gradient(r, [x, y], output_gradients=tf.constant([10, 20, 30], dtype=dt)) + +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test21_args_kwargs_fwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(*args, **kwargs): return args[0] * kwargs["y"] @@ -607,11 +806,14 @@ def test_fn(*args, **kwargs): assert dr.all(g == [40, 120, 200]) - -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "not implemented for 'Half'") def test22_flipped_args_kwargs_fwd(t, config): + skip_if_unsupported(config, t) + if config[0] == "tf": + pytest.skip("Skipped due to limited support of `tf.custom_gradient()`" + " for forward-mode AD.") @wrap_flipped(config) def test_fn(*args, **kwargs): return args[0] * 
kwargs["y"] @@ -633,13 +835,16 @@ def test_fn(*args, **kwargs): assert dr.all(rd == torch.tensor([40, 120, 200], dtype=dt)) - -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(3, *)') def test23_nested_arrays_bwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x, y): - return (x*y).sum() + if config[0] == 'tf': + return tf.reduce_sum(x*y) + else: + return (x*y).sum() x = t([1, 2], [3, 4], [5, 6]) y = t(10, 20, 30) @@ -651,13 +856,22 @@ def test_fn(x, y): assert dr.all(x.grad == [6000, 12000, 18000], axis=None) assert dr.all(y.grad == [1800, 4200, 6600]) - -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(3, *)') def test24_nested_arrays_fwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x, y): - return (x*y).sum() + # Note: this test was initially written using `x * y`, but this caused + # reference leaks from TensorFlow in a very weird corner case where + # all of the following applied: + # - LLVM backend + # - Both XLA and eager-mode variants enabled + # - Forward-mode gradients (backward-mode didn't leak) + if config[0] == 'tf': + return tf.reduce_sum(x + y) + else: + return (x + y).sum() x = t([1, 2], [3, 4], [5, 6]) y = t(10, 20, 30) @@ -668,12 +882,12 @@ def test_fn(x, y): g = dr.forward_to(r) assert dr.is_tensor_v(g) - assert g.array[0] == 10000 - + assert g.array[0] == 1410.0 -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test25_pytree_bwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x): return { @@ -694,33 +908,10 @@ def test_fn(x): assert dr.all(x.grad == [100, 200, 300]) assert dr.all(y.grad == [200, 400, 600]) - -@pytest.mark.parametrize('config', configs_torch) 
-@pytest.test_arrays('is_diff,float,shape=(*)') -def test26_flipped_pytree_bwd(t, config): - @wrap_flipped(config) - def test_fn(x): - return { - 123:(x[0]["hello"] + 2*x[1]["world"][0]) - } - - dt = torch_dtype(t) - x = torch.tensor([1, 2, 3], dtype=dt, requires_grad=True) - y = torch.tensor([4, 5, 6], dtype=dt, requires_grad=True) - xt = [ - { 'hello' : x }, - { 'world' : (y,) } - ] - rt = test_fn(xt) - r = rt[123] - - torch.autograd.backward(r, torch.tensor([100, 200, 300], dtype=dt)) - assert torch.all(x.grad == torch.tensor([100, 200, 300], dtype=dt)) - assert torch.all(y.grad == torch.tensor([200, 400, 600], dtype=dt)) - -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test25_pytree_fwd(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x): return { @@ -741,10 +932,56 @@ def test_fn(x): assert dr.all(r.grad == [90, 120, 150]) +@pytest.mark.parametrize('config', configs_torch + configs_tf) +@pytest.test_arrays('is_diff,float,shape=(*)') +def test26_flipped_pytree_bwd(t, config): + skip_if_unsupported(config, t) + @wrap_flipped(config) + def test_fn(x): + return { + 123:(x[0]["hello"] + 2*x[1]["world"][0]) + } + + if config[0] == 'torch': + dt = torch_dtype(t) + x = torch.tensor([1, 2, 3], dtype=dt, requires_grad=True) + y = torch.tensor([4, 5, 6], dtype=dt, requires_grad=True) + xt = [ + { 'hello' : x }, + { 'world' : (y,) } + ] + rt = test_fn(xt) + r = rt[123] -@pytest.mark.parametrize('config', configs_torch) + torch.autograd.backward(r, torch.tensor([100, 200, 300], dtype=dt)) + assert torch.all(x.grad == torch.tensor([100, 200, 300], dtype=dt)) + assert torch.all(y.grad == torch.tensor([200, 400, 600], dtype=dt)) + + elif config[0] == 'tf': + dt = tf_dtype(t) + x = tf.constant([1, 2, 3], dtype=dt) + y = tf.constant([4, 5, 6], dtype=dt) + with tf.GradientTape() as tape: + tape.watch([x, y]) + xt = [ + { 'hello' : x }, + { 'world' 
: (y,) } + ] + rt = test_fn(xt) + r = rt[123] + grad = tape.gradient(r, [x, y], output_gradients=tf.constant([100, 200, 300], dtype=dt)) + + assert tf.reduce_all(grad[0] == tf.constant([100, 200, 300], dtype=dt)) + assert tf.reduce_all(grad[1] == tf.constant([200, 400, 600], dtype=dt)) + +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float,shape=(*)') def test26_flipped_pytree_fwd(t, config): + skip_if_unsupported(config, t) + if config[0] == "tf": + pytest.skip("Skipped due to limited support of `tf.custom_gradient()`" + " for forward-mode AD.") + @wrap_flipped(config) def test_fn(x): return { @@ -772,10 +1009,10 @@ def test_fn(x): r, rd = fwd_ad.unpack_dual(r) assert torch.all(rd == torch.tensor([90, 120, 150], dtype=dt)) - -@pytest.mark.parametrize('config', configs) +@pytest.mark.parametrize('config', configs + configs_tf_jit) @pytest.test_arrays('is_diff,float32,shape=(*)') def test27_exception(t, config): + skip_if_unsupported(config, t) @wrap(config) def test_fn(x): raise RuntimeError('foo') @@ -784,22 +1021,25 @@ def test_fn(x): test_fn(t(1, 2, 3)) assert 'foo' in str(err.value.__cause__) - -@pytest.mark.parametrize('config', configs_torch) +@pytest.mark.parametrize('config', configs_torch + configs_tf) @pytest.test_arrays('is_diff,float32,shape=(*)') def test28_flipped_exception(t, config): + skip_if_unsupported(config, t) @wrap_flipped(config) def test_fn(x): raise RuntimeError('foo') with pytest.raises(RuntimeError) as err: - test_fn(torch.tensor([1, 2, 3])) + if config[0] == 'torch': + test_fn(torch.tensor([1, 2, 3])) + elif config[0] == 'tf': + test_fn(tf.constant([1, 2, 3])) assert 'foo' in str(err.value) - -@pytest.mark.parametrize('config', configs_torch) -@pytest.test_arrays('is_diff,llvm,float,shape=(*)') +@pytest.mark.parametrize('config', configs_torch + configs_tf) +@pytest.test_arrays('is_diff,float,shape=(*)') @pytest.skip_on(RuntimeError, "backend does not support the requested type of atomic 
reduction") def test29_flipped_non_tensor_output_bwd(t, config): + skip_if_unsupported(config, t) @wrap_flipped(config) def test_fn(x): a = dr.gather(t, x.array, 0) @@ -807,21 +1047,38 @@ def test_fn(x): c = dr.gather(t, x.array, 2) return a, b * 2, c * 3 - import torch - dt = torch_dtype(t) - x = torch.arange(3, dtype=dt) - x.requires_grad = True - - out1, out2, out3 = test_fn(x) - assert out1 == 0 - assert out2 == 2 - assert out3 == 6 - - (out1 + out2 + out3).backward() - assert torch.all(x.grad == torch.tensor([1, 2, 3], dtype=dt)) - - -@pytest.mark.parametrize('config', configs_torch) + if config[0] == 'torch': + dt = torch_dtype(t) + device = 'cuda' if 'cuda' in str(t) else 'cpu' + with torch.device(device): + x = torch.arange(3, dtype=dt) + x.requires_grad = True + + out1, out2, out3 = test_fn(x) + assert out1 == 0 + assert out2 == 2 + assert out3 == 6 + + (out1 + out2 + out3).backward() + assert torch.all(x.grad == torch.tensor([1, 2, 3], dtype=dt)) + + elif config[0] == 'tf': + dt = tf_dtype(t) + device = 'gpu' if 'cuda' in str(t) else 'cpu' + with tf.device(device): + x = tf.constant([0, 1, 2], dtype=dt) + with tf.GradientTape() as tape: + tape.watch(x) + out1, out2, out3 = test_fn(x) + assert out1 == 0 + assert out2 == 2 + assert out3 == 6 + + z = out1 + out2 + out3 + grad = tape.gradient(z, x) + assert tf.reduce_all(grad == tf.constant([1, 2, 3], dtype=dt)) + +@pytest.mark.parametrize('config', configs_torch + configs_tf + configs_tf_jit) @pytest.test_arrays('is_diff,float,shape=(*)') def test30_nested(t, config): @wrap(config) @@ -832,16 +1089,31 @@ def add(x, y): def test_fn(x, y): return x * add(x, y) - import torch - dt = torch_dtype(t) - x = torch.arange(3, dtype=dt) - y = x * 4 - x.requires_grad = True - y.requires_grad = True - - out = test_fn(x, y) - assert torch.all(out == torch.tensor([0, 5, 20], dtype=dt)) - - torch.autograd.backward(out, torch.tensor([1, 2, 3], dtype=dt)) - assert torch.all(x.grad == torch.tensor([0, 12, 36], dtype=dt)) - 
assert torch.all(y.grad == torch.tensor([0, 2, 6], dtype=dt)) + if config[0] == 'torch': + dt = torch_dtype(t) + x = torch.arange(3, dtype=dt) + y = x * 4 + x.requires_grad = True + y.requires_grad = True + + out = test_fn(x, y) + assert torch.all(out == torch.tensor([0, 5, 20], dtype=dt)) + + torch.autograd.backward(out, torch.tensor([1, 2, 3], dtype=dt)) + assert torch.all(x.grad == torch.tensor([0, 12, 36], dtype=dt)) + assert torch.all(y.grad == torch.tensor([0, 2, 6], dtype=dt)) + + elif config[0] == 'tf': + dt = tf_dtype(t) + x = tf.constant([0, 1, 2], dtype=dt) + y = x * 4 + + with tf.GradientTape() as tape: + tape.watch([x, y]) + out = test_fn(x, y) + grad = tape.gradient(out, [x, y], + output_gradients=tf.constant([1, 2, 3], dtype=dt)) + + assert tf.reduce_all(out == tf.constant([0, 5, 20], dtype=dt)) + assert tf.reduce_all(grad[0] == tf.constant([0, 12, 36], dtype=dt)) + assert tf.reduce_all(grad[1] == tf.constant([0, 2, 6], dtype=dt))