
The normalization for the discriminator #1

@cyoahs

Hello, I am a little confused about the normalization for the discriminator in amp_ppo.py, and I think it may be a bug.

                policy_state, policy_next_state = sample_amp_policy  # (len, 43)
                expert_state, expert_next_state = sample_amp_expert
                if self.amp_normalizer is not None:
                    with torch.no_grad():
                        policy_state = self.amp_normalizer.normalize_torch(policy_state, self.device)
                        policy_next_state = self.amp_normalizer.normalize_torch(policy_next_state, self.device)
                        expert_state = self.amp_normalizer.normalize_torch(expert_state, self.device)
                        expert_next_state = self.amp_normalizer.normalize_torch(expert_next_state, self.device)
                policy_d = self.discriminator(torch.cat([policy_state, policy_next_state], dim=-1))
                expert_d = self.discriminator(torch.cat([expert_state, expert_next_state], dim=-1))

In lines 240-249 shown above, sample_amp_expert holds the raw AMP expert data before normalization, and expert_state is the data after normalization. The normalized data is what the discriminator receives as input.

However, when calculating the gradient penalty in lines 270-281 below, sample_amp_expert is used directly, without normalization. According to the original paper, the gradient penalty should be evaluated at the expert data as the discriminator sees it, i.e. the normalized expert samples. As written, the penalty minimizes the discriminator's gradients on other parts of the data manifold rather than on the expert data as intended.

                sample_amp_expert = torch.cat(sample_amp_expert, dim=-1)
                sample_amp_expert.requires_grad = True
                disc = self.discriminator.amp_linear(self.discriminator.trunk(sample_amp_expert))
                ones = torch.ones(disc.size(), device=disc.device)
                disc_demo_grad = torch.autograd.grad(disc, sample_amp_expert,
                                                     grad_outputs=ones,
                                                     create_graph=True, retain_graph=True, only_inputs=True)
                disc_demo_grad = disc_demo_grad[0]
                disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
                grad_pen_loss = self.disc_grad_penalty * torch.mean(disc_demo_grad)
                # amp_loss += 0.2 * disc_grad_penalty  # self._disc_grad_penalty:0.2
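If this is indeed a bug, one possible fix is to normalize the expert transitions first and then enable gradients on the normalized input, so the penalty is computed at the same points the discriminator is trained on. Below is a rough, self-contained sketch of that idea; `Normalizer` and `Discriminator` here are minimal stand-ins I made up for illustration, not the repo's actual classes:

```python
import torch
import torch.nn as nn

# Minimal stand-ins (assumptions, not the repo's actual classes).
class Normalizer:
    def __init__(self, dim):
        self.mean = torch.zeros(dim)
        self.std = torch.ones(dim)

    def normalize_torch(self, x, device):
        return (x - self.mean.to(device)) / self.std.to(device)

class Discriminator(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Mirrors the trunk/amp_linear split used in the quoted code.
        self.trunk = nn.Sequential(nn.Linear(2 * dim, 32), nn.ReLU())
        self.amp_linear = nn.Linear(32, 1)

dim, device = 43, "cpu"
normalizer, disc_net = Normalizer(dim), Discriminator(dim)
expert_state = torch.randn(8, dim)
expert_next_state = torch.randn(8, dim)

# Normalize first, then require gradients on the *normalized* input,
# so the penalty is evaluated on the expert data manifold the
# discriminator actually sees.
expert_input = torch.cat(
    [normalizer.normalize_torch(expert_state, device),
     normalizer.normalize_torch(expert_next_state, device)], dim=-1)
expert_input.requires_grad_(True)

disc = disc_net.amp_linear(disc_net.trunk(expert_input))
grad = torch.autograd.grad(disc, expert_input,
                           grad_outputs=torch.ones_like(disc),
                           create_graph=True, retain_graph=True,
                           only_inputs=True)[0]
grad_pen_loss = 0.2 * torch.sum(torch.square(grad), dim=-1).mean()
```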

Finally, when updating the normalizer, the normalized data is used in lines 308-310, as shown below. To my understanding of the code, the data before normalization should be used here to estimate the running mean and std. This part leads to incorrect mean and std estimates, though it probably does not have too large an impact on the results.

                if self.amp_normalizer is not None:
                    self.amp_normalizer.update(policy_state.cpu().numpy())
                    self.amp_normalizer.update(expert_state.cpu().numpy())
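To illustrate why the update should see raw data: a running normalizer tracks the mean and std of the data distribution, and feeding it already-normalized samples drags its statistics toward mean 0 and std 1 instead of the true values. Here is a small, self-contained demonstration; `RunningNormalizer` is a minimal Welford-style stand-in I wrote for this sketch, not the repo's actual amp_normalizer:

```python
import numpy as np

class RunningNormalizer:
    """Minimal running mean/variance tracker (parallel-variance merge)."""
    def __init__(self, dim):
        self.count = 1e-4
        self.mean = np.zeros(dim)
        self.var = np.ones(dim)

    def update(self, batch):
        b_mean, b_var, b_count = batch.mean(0), batch.var(0), batch.shape[0]
        delta = b_mean - self.mean
        tot = self.count + b_count
        self.mean = self.mean + delta * b_count / tot
        m_a = self.var * self.count
        m_b = b_var * b_count
        self.var = (m_a + m_b + delta ** 2 * self.count * b_count / tot) / tot
        self.count = tot

    def normalize(self, x):
        return (x - self.mean) / np.sqrt(self.var + 1e-8)

norm = RunningNormalizer(3)
raw = np.random.default_rng(0).normal(5.0, 2.0, size=(1000, 3))

# Correct: update with the *raw* batch, so the running statistics
# track the true data distribution (mean ~5, std ~2 here).
norm.update(raw)
normalized = norm.normalize(raw)

# Updating with `normalized` instead would pull the running mean
# toward 0 and the std toward 1, away from the true statistics.
```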
