Labels: bug
Description
I am trying to use a function whose gradient must be in higher precision on the backward pass than its values are on the forward pass. A minimal reproducible example is shown below:
```python
import jax
import jax.numpy as jnp

# Dummy function that requires float32 gradient
@jax.custom_vjp
def identity(x):
    return x

def identity_fwd(x):
    return x, None

def identity_bwd(res, g):
    return (g.astype(jnp.float32),)

identity.defvjp(identity_fwd, identity_bwd)

def model(x):
    def scannable_identity_fn(carry, _):
        return identity(carry), None

    x = identity(x)  # without these identity calls, the scan in/out types are different
    x = jax.lax.scan(scannable_identity_fn, x, None, length=2)[0]
    x = identity(x)
    return x

# Running with x in bfloat16 fails due to scan, as the gradient is fp32
x = jnp.asarray(8, dtype=jnp.bfloat16)
grad = jax.grad(model)(x)
print(grad)
```

When using this with a scan I get the following error.
---------------------------------------------------------------------------
MLIRError Traceback (most recent call last)
File ~/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py:1212, in lower_jaxpr_to_module(***failed resolving arguments***)
   1211 try:
-> 1212 if not ctx.module.operation.verify():
   1213 raise ValueError(
   1214 "Cannot lower jaxpr with verifier errors. " +
   1215 dump_module_message(ctx.module, "verification"))
MLIRError: Verification failed:
error: "jit(scan)/jit(main)/while/body/closed_call"(callsite("model"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":26:8) at callsite("<module>"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":33:7) at callsite("InteractiveShell.run_code"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at "IPythonKernel.do_execute"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26))))))))))): 'func.call' op operand type mismatch: expected operand type 'tensor<bf16>', but provided 'tensor<f32>' for operand number 0
note: "jit(scan)/jit(main)/while/body/closed_call"(callsite("model"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":26:8) at callsite("<module>"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":33:7) at callsite("InteractiveShell.run_code"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at "IPythonKernel.do_execute"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26))))))))))): see current operation: %6 = "func.call"(%arg2) <{callee = @None}> : (tensor<f32>) -> tensor<f32>
The above exception was the direct cause of the following exception:
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel_launcher.py:18
     16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()
File ~/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/traitlets/config/application.py:1075, in launch_instance()
...
'func.call' op operand type mismatch: expected operand type 'tensor<bf16>', but provided 'tensor<f32>' for operand number 0
at loc("jit(scan)/jit(main)/while/body/closed_call"(callsite("model"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":26:8) at callsite("<module>"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":33:7) at callsite("InteractiveShell.run_code"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at "IPythonKernel.do_execute"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26))))))))))))
see current operation: %6 = "func.call"(%arg2) <{callee = @None}> : (tensor<f32>) -> tensor<f32>
at loc("jit(scan)/jit(main)/while/body/closed_call"(callsite("model"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":26:8) at callsite("<module>"("/var/folders/84/ltzfvw692yl9h64qv9g44c5c0000gn/T/ipykernel_75635/2738124691.py":33:7) at callsite("InteractiveShell.run_code"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at "IPythonKernel.do_execute"("/Users/liam/PycharmProjects/waffle/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26))))))))))))
The module was dumped to jax_ir0034_jit_scan_verification.mlir.
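For comparison, here is a sketch (not part of the original report) of the same reproducer with a bwd rule that still does its arithmetic in float32 but casts the result back to the primal's dtype before returning it. The assumption, consistent with the verifier message above, is that the scan body expects the carry's cotangent to have the same dtype as the carry (bf16 here), so this variant should lower without the error. The names `identity_same_dtype` and `model_same_dtype` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Identity whose bwd rule computes in float32 but returns the primal dtype.
@jax.custom_vjp
def identity_same_dtype(x):
    return x

def identity_same_dtype_fwd(x):
    return x, None

def identity_same_dtype_bwd(res, g):
    # For an identity, g has the same dtype as the primal input (bf16 here).
    g32 = g.astype(jnp.float32)     # any gradient arithmetic happens in float32
    return (g32.astype(g.dtype),)   # cast back so the cotangent matches the carry

identity_same_dtype.defvjp(identity_same_dtype_fwd, identity_same_dtype_bwd)

def model_same_dtype(x):
    def body(carry, _):
        return identity_same_dtype(carry), None

    x = identity_same_dtype(x)
    x = jax.lax.scan(body, x, None, length=2)[0]
    return identity_same_dtype(x)

x = jnp.asarray(8, dtype=jnp.bfloat16)
grad = jax.grad(model_same_dtype)(x)  # gradient comes back as bfloat16
print(grad, grad.dtype)
```

The cost is that the precision gained inside the bwd rule is discarded as soon as the cotangent leaves it.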
I can resolve this by using the same dtype on the forward and backward passes, but I am using a similar setup to quantize activations on the forward pass while keeping gradients in higher precision. Is this a limitation of scan, or is there a way to use different dtypes for the forward and backward passes inside a scan?
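One possible workaround, sketched here under assumptions rather than confirmed as the recommended approach: keep the scan carry in float32 and simulate bfloat16 activations with "fake quantization", i.e. round values to bfloat16 precision and cast them straight back to float32. A `custom_vjp` with a straight-through backward rule then passes the float32 cotangent through unchanged, so the carry and its cotangent share a dtype and scan never sees a bf16/f32 mismatch. The helper `fake_quant_bf16` and the `* 1.5` body are hypothetical stand-ins for the real quantizer and layer.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: round to bfloat16 precision but keep the input dtype
# (float32), so the scan carry and its cotangent both stay float32.
@jax.custom_vjp
def fake_quant_bf16(x):
    return x.astype(jnp.bfloat16).astype(x.dtype)

def fake_quant_bf16_fwd(x):
    return fake_quant_bf16(x), None

def fake_quant_bf16_bwd(res, g):
    # Straight-through estimator: pass the float32 cotangent through unchanged.
    return (g,)

fake_quant_bf16.defvjp(fake_quant_bf16_fwd, fake_quant_bf16_bwd)

def model_fake_quant(x):
    def body(carry, _):
        # Quantize the activation, then apply a stand-in "layer".
        return fake_quant_bf16(carry) * 1.5, None

    return jax.lax.scan(body, x, None, length=2)[0]

x = jnp.asarray(8.0, dtype=jnp.float32)
grad = jax.grad(model_fake_quant)(x)  # float32 gradient, no dtype mismatch
print(grad, grad.dtype)
```

The trade-off is memory: the carried values occupy float32 storage even though they only hold bfloat16 precision.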
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.5.0
jaxlib: 0.5.0
numpy: 2.0.2
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct 2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='Liams-Work-MacBook-Pro.local', release='24.3.0', version='Darwin Kernel Version 24.3.0: Thu Jan 2 20:24:24 PST 2025; root:xnu-11215.81.4~3/RELEASE_ARM64_T6030', machine='arm64')
Note: the same issue occurs on other platforms.