Stop action gradient in policy gradient loss #109

@danijar

The current implementation of policy_gradient_loss is:

log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
loss_per_timestep = -log_pi_a_t * adv_t

It's good that the gradients are already stopped around the advantages, but they should also be stopped around the actions to ensure an unbiased gradient estimator.

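This matters when the actions are sampled as part of the training graph (MPO-style algorithms, imagination training with world models) rather than read from the replay buffer, and the actor distribution implements a gradient for sample() (e.g. Gaussian or straight-through categorical distributions). In that case, differentiating logprob(a_t, logits_t) also propagates a gradient through a_t back into the policy parameters, on top of the intended score-function term.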
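
To illustrate the effect, here is a minimal sketch in plain JAX, assuming a unit-variance Gaussian policy with a reparameterized sample. The names gaussian_logprob, pg_loss and stop_action_gradient are hypothetical and not part of the library; the sketch only compares the gradient with and without a stop_gradient around the action.

```python
import jax
import jax.numpy as jnp


def gaussian_logprob(a, mu, sigma):
    # Log-density of a scalar Gaussian N(mu, sigma^2) evaluated at a.
    return (-0.5 * ((a - mu) / sigma) ** 2
            - jnp.log(sigma)
            - 0.5 * jnp.log(2.0 * jnp.pi))


def pg_loss(mu, eps, adv, stop_action_gradient):
    # Hypothetical single-step loss; sigma is fixed to 1 for simplicity.
    sigma = jnp.asarray(1.0)
    # Reparameterized sample: differentiable with respect to mu.
    a = mu + sigma * eps
    if stop_action_gradient:
        # Treat the sampled action as a constant, as proposed in this issue.
        a = jax.lax.stop_gradient(a)
    adv = jax.lax.stop_gradient(adv)
    return -gaussian_logprob(a, mu, sigma) * adv


mu = jnp.asarray(0.5)
eps = jnp.asarray(1.5)   # fixed noise draw so the comparison is deterministic
adv = jnp.asarray(2.0)

# Gradient leaks through the action: (a - mu) no longer depends on mu,
# so the score-function term cancels.
grad_leaky = jax.grad(pg_loss)(mu, eps, adv, False)

# Action treated as constant: the usual score-function gradient remains.
grad_fixed = jax.grad(pg_loss)(mu, eps, adv, True)

print(grad_leaky, grad_fixed)
```

In this toy setup the gradient with respect to mu vanishes when the action carries a gradient, because a - mu reduces to the noise term, while the stopped version yields -adv * (a - mu) / sigma^2. Other distributions will behave differently in detail, but the extra pathwise term is present whenever sample() is differentiable.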
