Feature request: Add an unroll argument (default 1) to functions that use lax.scan, and pass it through to the underlying scan call.
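On GPUs, lax.scan can be slow because each loop iteration typically incurs at least one kernel launch. Unrolling several iterations into the loop body reduces the number of iterations, and therefore the number of launches, at the cost of a larger program and longer compile times.
Exposing this argument would let users choose an unroll factor that suits their workload, which could speed up computation.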
I can submit a PR for this.
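As a rough illustration of the proposed pattern, here is a minimal sketch built around a hypothetical function, running_sum (not an existing API). It accepts unroll and forwards it to jax.lax.scan. The numerical result does not depend on unroll; only the structure of the compiled loop changes.

```python
import jax
import jax.numpy as jnp
from jax import lax


def running_sum(xs, unroll=1):
    """Hypothetical example: cumulative sum via lax.scan with a user-chosen unroll factor."""

    def step(carry, x):
        # Add the current element to the running total and emit the new total.
        carry = carry + x
        return carry, carry

    init = jnp.zeros((), dtype=xs.dtype)
    # unroll=k asks scan to place k iterations in each loop-body step.
    # The default of 1 preserves the current behaviour.
    _, ys = lax.scan(step, init, xs, unroll=unroll)
    return ys


xs = jnp.arange(10.0)
print(running_sum(xs, unroll=1))
print(running_sum(xs, unroll=4))  # same values; 4 does not need to divide the length
```

Because unroll must be a concrete Python int, a jitted wrapper would need to mark it as static, for example jax.jit(running_sum, static_argnames="unroll"). Whether a particular factor actually helps depends on the workload and should be measured.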