Bump?
Hi there, thanks for your work and for looking at this. I am trying to port my environment definition to JAX in the hope that I can run massively parallel environment rollouts for my RL algorithm.
The task is navigation in a 3D maze. Leaving the details aside for now, I based my implementation on Gym and Gymnax conventions: https://github.com/m-krastev/gymnax
To exploit massive parallelism, I need essentially all of my code to be JIT-compilable. In this case the code does compile, but it runs roughly 1-2x slower than similar code written with NumPy + Torch (with most of the work done on the CPU). Profiling shows that a large fraction of the time is spent in memcpyD2D (device-to-device memory copies), and I am struggling to find what prevents the compiler from performing these operations in place.
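One possible source of extra device-to-device copies (a guess, not a diagnosis of the code above) is that a jitted function's inputs are immutable from XLA's point of view: if the caller still owns the input array, the output of an `.at[...].set(...)` update generally needs its own buffer. `jax.jit` accepts `donate_argnums` to tell XLA it may reuse an input buffer for the output. The sketch below assumes a large 3D volume is part of the state; `mark_voxel` and the shapes are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical update on a large 3D volume held in the environment state.
# donate_argnums=(0,) allows XLA to reuse `volume`'s buffer for the output,
# so the functional .at[].set() update may avoid allocating a new array.
@partial(jax.jit, donate_argnums=(0,))
def mark_voxel(volume, idx):
    return volume.at[idx].set(1.0)


volume = jnp.zeros((64, 64, 64), dtype=jnp.float32)
idx = (jnp.int32(10), jnp.int32(20), jnp.int32(30))
# Rebind the name: a donated input buffer must not be used afterwards.
volume = mark_voxel(volume, idx)
```

Donation only helps on backends that support it; elsewhere JAX emits a warning and falls back to copying. Within a single jitted function (for example, state carried through `jax.lax.scan` or `fori_loop`), XLA can usually perform such updates in place without donation. To check where copies remain, the compiled HLO can be inspected with `jax.jit(f).lower(*args).compile().as_text()`.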
I include my main code (gymnax/environments/medical/small_bowel.py), with some parts omitted to keep it to a reasonable length. SmallBowelParams is meant to be immutable once instantiated; only the state is passed around with new values.
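The params/state split described above can be sketched as two pytrees, with the step function vmapped over a batch of states while the same params are broadcast to every environment. This is a minimal illustration under assumptions: `EnvParams`, `EnvState`, `step`, and their fields are hypothetical stand-ins (not SmallBowelParams), the PRNG key argument of a Gymnax-style step is omitted, and a `NamedTuple` is used only to keep the sketch dependency-free (any pytree container works).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvParams(NamedTuple):
    # Hypothetical immutable configuration, fixed after construction.
    grid_size: int = 32
    step_size: float = 1.0


class EnvState(NamedTuple):
    # Only the state changes between steps.
    position: jax.Array  # shape (3,) per environment
    time: jax.Array      # scalar int per environment


def step(state: EnvState, action: jax.Array, params: EnvParams) -> EnvState:
    # Pure function: returns a new state instead of mutating the old one.
    new_pos = jnp.clip(state.position + params.step_size * action,
                       0, params.grid_size - 1)
    return EnvState(position=new_pos, time=state.time + 1)


# Batch over states and actions (axis 0); params are shared (in_axes=None).
batched_step = jax.jit(jax.vmap(step, in_axes=(0, 0, None)))

num_envs = 4
params = EnvParams()
state = EnvState(position=jnp.zeros((num_envs, 3)),
                 time=jnp.zeros((num_envs,), dtype=jnp.int32))
actions = jnp.ones((num_envs, 3))
state = batched_step(state, actions, params)
```

Values that determine array shapes (for example a grid size used in `jnp.zeros`) cannot be traced leaves like this; they would need to be static, e.g. stored on the environment object or passed via `static_argnums`.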
Environment code:
One particularly expensive function, in both Torch and JAX, is a line-drawing algorithm that generates the coordinates of a line in 3D space and then dilates it to a given radius. I tried two approaches: one that uses convolutions to approximate dilation, and another that draws spheres. Here is the second approach (which is slightly faster); the first one is also available in my code.
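For reference, a sphere-drawing variant can be kept fully jittable by sampling a fixed number of points along the segment and OR-ing a sphere mask around each one. This is a minimal sketch, not the code from the linked repository: `draw_thick_line` and its `shape`/`num_points` parameters are hypothetical, and both are treated as static so that every array shape is known at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("shape", "num_points"))
def draw_thick_line(start, end, radius, shape=(32, 32, 32), num_points=64):
    """Boolean mask of voxels within `radius` of any of `num_points`
    evenly spaced samples on the segment start -> end (union of spheres)."""
    # Integer voxel coordinates as floats, shape (*shape, 3).
    axes = [jnp.arange(s) for s in shape]
    grid = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)
    grid = grid.astype(jnp.float32)
    # Sample points along the segment, shape (num_points, 3).
    t = jnp.linspace(0.0, 1.0, num_points)[:, None]
    points = start[None, :] + t * (end - start)[None, :]
    r2 = radius ** 2

    def stamp_sphere(i, mask):
        # Squared distance from every voxel to sample point i.
        d2 = jnp.sum((grid - points[i]) ** 2, axis=-1)
        return mask | (d2 <= r2)

    # fori_loop keeps one volume-sized carry instead of num_points volumes.
    return jax.lax.fori_loop(0, num_points, stamp_sphere,
                             jnp.zeros(shape, dtype=bool))


mask = draw_thick_line(jnp.array([4.0, 4.0, 4.0]),
                       jnp.array([20.0, 4.0, 4.0]),
                       2.0, shape=(32, 32, 32), num_points=64)
```

Each loop iteration still touches the whole volume, so the cost grows with `num_points` times the grid size; restricting the work to a bounding box around the segment, or computing each voxel's distance to the segment in a single pass, are possible alternatives if this remains the bottleneck.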