
Issue with scan + custom_vjp where backward pass is in higher precision than forward pass #27144

@liamclarkza

Description


I am trying to use a function whose gradient must be in higher precision on the backward pass than its input is on the forward pass. A minimal reproducible example is shown below:

import jax
import jax.numpy as jnp

# Dummy identity function whose backward pass returns a float32 gradient
@jax.custom_vjp
def identity(x):
    return x

def identity_fwd(x):
    return x, None

def identity_bwd(res, g):
    return (g.astype(jnp.float32),)

identity.defvjp(identity_fwd, identity_bwd)


def model(x):

    def scannable_identity_fn(carry, _):
        return identity(carry), None
    
    x = identity(x) # without these identity calls, the scan in/out types are different
    x = jax.lax.scan(scannable_identity_fn, x, None, length=2)[0]
    x = identity(x)
    return x


# Running with x in bfloat16 fails due to scan, as the gradient is fp32
x = jnp.asarray(8, dtype=jnp.bfloat16)
grad = jax.grad(model)(x)
print(grad)

When using this inside a scan, I get the following error:

---------------------------------------------------------------------------
MLIRError                                 Traceback (most recent call last)
File ~/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py:1212, in lower_jaxpr_to_module(***failed resolving arguments***)
   1211 try:
-> 1212   if not ctx.module.operation.verify():
   1213     raise ValueError(
   1214         "Cannot lower jaxpr with verifier errors. " +
   1215         dump_module_message(ctx.module, "verification"))

MLIRError: Verification failed:
error: "jit(scan)/jit(main)/while/body/closed_call"(callsite("model"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":26:8) at callsite("<module>"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":33:7) at callsite("InteractiveShell.run_code"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at "IPythonKernel.do_execute"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26))))))))))): 'func.call' op operand type mismatch: expected operand type 'tensor<bf16>', but provided 'tensor<f32>' for operand number 0
 note: "jit(scan)/jit(main)/while/body/closed_call"(callsite("model"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":26:8) at callsite("<module>"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":33:7) at callsite("InteractiveShell.run_code"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at "IPythonKernel.do_execute"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26))))))))))): see current operation: %6 = "func.call"(%arg2) <{callee = @None}> : (tensor<f32>) -> tensor<f32>

The above exception was the direct cause of the following exception:

JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel_launcher.py:18
     16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File ~/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/traitlets/config/application.py:1075, in launch_instance()
...
	'func.call' op operand type mismatch: expected operand type 'tensor<bf16>', but provided 'tensor<f32>' for operand number 0
		at loc("jit(scan)/jit(main)/while/body/closed_call"(callsite("model"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":26:8) at callsite("<module>"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":33:7) at callsite("InteractiveShell.run_code"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at "IPythonKernel.do_execute"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26))))))))))))
	see current operation: %6 = "func.call"(%arg2) <{callee = @None}> : (tensor<f32>) -> tensor<f32>
		at loc("jit(scan)/jit(main)/while/body/closed_call"(callsite("model"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":26:8) at callsite("<module>"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":33:7) at callsite("InteractiveShell.run_code"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at "IPythonKernel.do_execute"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26))))))))))))
The module was dumped to jax_ir0034_jit_scan_verification.mlir.

While I can work around this by using the same dtype on the forward and backward passes, I am using a similar setup to quantize activations on the forward pass while keeping gradients in higher precision. Is this a limitation of scan, or is there a way to use different dtypes for the forward and backward passes inside a scan?
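One possible workaround, sketched here under stated assumptions and not confirmed as the recommended JAX approach, is to avoid changing dtypes across the `custom_vjp` boundary at all. The scan carry stays in float32, and the bfloat16 quantization is emulated on the forward pass ("fake quantization": round to bfloat16, then cast back to float32), with a straight-through estimator on the backward pass. Primal and cotangent dtypes then always match, so the scan's forward and backward carry types are consistent. The helper `fake_quantize_bf16` and the `* 2.0` scan body are hypothetical illustrations. Note the trade-off: activations are stored as float32, so this keeps bfloat16 numerics but not bfloat16 memory savings.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper: "fake" bfloat16 quantization. The forward pass rounds
# values to bfloat16 precision but returns float32, so the primal output and
# its cotangent share a dtype.
@jax.custom_vjp
def fake_quantize_bf16(x):
    return x.astype(jnp.bfloat16).astype(jnp.float32)


def fake_quantize_bf16_fwd(x):
    # No residuals are needed for a straight-through backward pass.
    return x.astype(jnp.bfloat16).astype(jnp.float32), None


def fake_quantize_bf16_bwd(res, g):
    # Straight-through estimator: pass the float32 cotangent through unchanged.
    return (g,)


fake_quantize_bf16.defvjp(fake_quantize_bf16_fwd, fake_quantize_bf16_bwd)


def model(x):
    def body(carry, _):
        # Hypothetical per-step computation on the quantized activation.
        return fake_quantize_bf16(carry) * 2.0, None

    # The carry is float32 on both the forward and backward pass.
    x, _ = jax.lax.scan(body, x, None, length=2)
    return x


x = jnp.asarray(8.0, dtype=jnp.float32)
print(jax.grad(model)(x))  # 4.0
```

Because the straight-through backward pass treats the quantizer as the identity, the gradient here is just the product of the two `* 2.0` factors.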

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.0.2
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct  2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Liams-Work-MacBook-Pro.local', release='24.3.0', version='Darwin Kernel Version 24.3.0: Thu Jan  2 20:24:24 PST 2025; root:xnu-11215.81.4~3/RELEASE_ARM64_T6030', machine='arm64')

Note: the same issue occurs on other platforms.

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests