Open
Description
The current implementation of policy_gradient_loss is:
log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
loss_per_timestep = -log_pi_a_t * adv_t

It's good that the gradients are already stopped around the advantages, but they should also be stopped around the actions to ensure an unbiased gradient estimator.
This is important when the actions are sampled as part of the training graph (MPO-style algorithms, imagination training with world models) rather than coming from the replay buffer, and the actor distribution implements a gradient for sample() (e.g. Gaussian, or straight-through categoricals).
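To make the concern concrete, here is a minimal sketch with a reparameterized Gaussian actor. It is not the library's code: `gaussian_logprob`, `pg_loss`, and the `stop_action_gradient` flag are hypothetical names used only for illustration, and the scalar inputs are made up.

```python
import jax
import jax.numpy as jnp


def gaussian_logprob(a, mu, sigma):
    # Log-density of a univariate Gaussian N(mu, sigma^2) evaluated at a.
    return (-0.5 * ((a - mu) / sigma) ** 2
            - jnp.log(sigma)
            - 0.5 * jnp.log(2.0 * jnp.pi))


def pg_loss(mu, eps, adv, sigma=1.0, stop_action_gradient=True):
    # Hypothetical toy loss, assuming a reparameterized sample a = mu + sigma * eps,
    # so the action is a differentiable function of the policy mean.
    a = mu + sigma * eps
    if stop_action_gradient:
        # Proposed fix: treat the sampled action as a constant.
        a = jax.lax.stop_gradient(a)
    # Advantages are already treated as constants in the current loss.
    adv = jax.lax.stop_gradient(adv)
    return -gaussian_logprob(a, mu, sigma) * adv


mu, eps, adv = 0.5, 1.5, 2.0

# Gradient w.r.t. the policy mean, with and without stopping the action gradient.
grad_with_stop = jax.grad(
    lambda m: pg_loss(m, eps, adv, stop_action_gradient=True))(mu)
grad_without_stop = jax.grad(
    lambda m: pg_loss(m, eps, adv, stop_action_gradient=False))(mu)

print(grad_with_stop, grad_without_stop)
```

In this toy case the pathwise gradient through the sampled action cancels the score-function term with respect to the mean exactly, so without the stop the mean receives no policy-gradient signal at all. With the stop, the gradient is the usual score-function estimate, `-adv * (a - mu) / sigma**2`.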