How do I interpret these traces? #31039
jerry-kobold asked this question in Q&A
Reply:

Hey, check this and let me know. For diagnosis, and to separate compilation from execution on the first call:
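A minimal sketch of timing compilation separately from execution on the first call. `log_likelihood` here is a hypothetical stand-in for the question's omitted function. `jax.jit(...).lower(...).compile()` compiles ahead of time without running anything, and `block_until_ready` makes the timer wait for the actual computation instead of returning as soon as the work has been dispatched.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the question's (omitted) log_likelihood function.
def log_likelihood(theta):
    return -0.5 * jnp.sum(theta ** 2)


def time_compile_and_execute(fn, *args):
    """Return (compile_seconds, execute_seconds, result) for jax.jit(fn)."""
    jitted = jax.jit(fn)

    t0 = time.perf_counter()
    # Ahead-of-time lowering + compilation: nothing is executed here.
    compiled = jitted.lower(*args).compile()
    t1 = time.perf_counter()

    # Run the already-compiled executable and block, so the timer measures
    # the computation itself rather than only its asynchronous dispatch.
    result = compiled(*args)
    result.block_until_ready()
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, result


compile_s, execute_s, value = time_compile_and_execute(log_likelihood, jnp.ones(4))
print(f"compile: {compile_s:.4f}s  execute: {execute_s:.6f}s  value: {value}")
```

If the first call of the real step is dominated by the compile time, it shows up as a long span in a profile even though very little computation is happening.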
  
Hello JAX community! I'm fairly new to JAX and am trying to use it for some physics simulations by connecting an existing model to the `blackjax` package. I have a function that looks like this:

I've omitted the internals of the `log_likelihood` function because they are based on our internal code, but the function computes the right result (as defined by our internal tests) and is `jit`-able. This function implements one forward step of a model that I now want to hook into `blackjax` to do HMC.

When I trace one run of the step like so:
I get the following trace:

I've highlighted a part of the trace that centers on the `ThunkExecutor::Execute` method in the client thread, which helpfully tells me that it's waiting for the computation to complete. One thing that jumps out at me is that the `ThunkExecutor::Execute` calls in the worker threads seem a bit sparse, but OK, fine. It seems generally reasonable.

Now I hook this up to the `blackjax` framework (using their example) and try to run the warmup step to determine the inverse mass matrix and step size:

This runs just fine but seems to be very slow. When I look at the profiling output, I see the following:
I am using a 32-core CPU as my device, so what's shown is only a subset, but what I am struggling to understand is why the `ThunkExecutor::Execute` shown there in brown is waiting so long when, naively, the computation looks finished. What am I missing here, and how should I interpret what's happening? I realize that this references `blackjax`, but I think my question is more about how to understand the difference between these traces than about the specifics of the library.
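One hedged way to read the difference between the two traces rests on an assumption worth checking in the `blackjax` source: that the warmup runs all of its iterations inside a single compiled loop (e.g. `jax.lax.scan` or `jax.lax.while_loop`). If so, the whole warmup is one executable launch. The client thread's top-level execution span would then be expected to stay open until the last iteration finishes, even when any single iteration's work looks done. The sketch below uses a hypothetical toy `step` (gradient ascent on a Gaussian, not `blackjax`'s algorithm). It runs the same iterations once as a `jax.lax.scan` and once as a Python loop over a jitted step, wrapping each variant in a `jax.profiler.TraceAnnotation` so the two can be compared side by side in the trace viewer.

```python
import tempfile
from functools import partial

import jax
import jax.numpy as jnp


def log_density(theta):
    # Hypothetical toy target (standard Gaussian); named_scope labels its ops.
    with jax.named_scope("log_density"):
        return -0.5 * jnp.sum(theta ** 2)


def step(theta, _):
    # Hypothetical toy update (gradient ascent) standing in for one warmup
    # step; this is not blackjax's algorithm.
    theta = theta + 0.1 * jax.grad(log_density)(theta)
    return theta, log_density(theta)


@partial(jax.jit, static_argnames="num_steps")
def run_scanned(theta0, num_steps):
    # Every iteration is compiled into one executable: a single launch total.
    return jax.lax.scan(step, theta0, xs=None, length=num_steps)


jitted_step = jax.jit(step)


def run_python_loop(theta0, num_steps):
    # One executable launch per iteration, with Python overhead in between.
    theta, log_probs = theta0, []
    for _ in range(num_steps):
        theta, lp = jitted_step(theta, None)
        log_probs.append(lp)
    return theta, jnp.stack(log_probs)


theta0 = jnp.ones(3)
# Compile both variants first so compilation stays out of the profile.
jax.block_until_ready(run_scanned(theta0, num_steps=100))
jax.block_until_ready(run_python_loop(theta0, 100))

log_dir = tempfile.mkdtemp()
with jax.profiler.trace(log_dir):
    with jax.profiler.TraceAnnotation("scanned"):
        # Block inside the annotation so the span covers the computation,
        # not just its asynchronous dispatch.
        theta_scan, lp_scan = jax.block_until_ready(
            run_scanned(theta0, num_steps=100)
        )
    with jax.profiler.TraceAnnotation("python_loop"):
        theta_loop, lp_loop = jax.block_until_ready(run_python_loop(theta0, 100))

print(f"trace written to {log_dir}")
```

The resulting `log_dir` can be opened with TensorBoard's profile plugin. Both variants compute the same values; what differs is how many executable launches appear on the client thread and how the time between them is spent.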