Description
Hello, I am a little confused about the normalization of the discriminator inputs in the file amp_ppo.py, and I think it may be a bug.
```python
policy_state, policy_next_state = sample_amp_policy  # (len, 43)
expert_state, expert_next_state = sample_amp_expert
if self.amp_normalizer is not None:
    with torch.no_grad():
        policy_state = self.amp_normalizer.normalize_torch(policy_state, self.device)
        policy_next_state = self.amp_normalizer.normalize_torch(policy_next_state, self.device)
        expert_state = self.amp_normalizer.normalize_torch(expert_state, self.device)
        expert_next_state = self.amp_normalizer.normalize_torch(expert_next_state, self.device)
policy_d = self.discriminator(torch.cat([policy_state, policy_next_state], dim=-1))
expert_d = self.discriminator(torch.cat([expert_state, expert_next_state], dim=-1))
```

In lines 240-249 shown above, `sample_amp_policy` and `sample_amp_expert` hold the raw AMP data before normalization, while `policy_state` and `expert_state` (after the `if` block) are the normalized versions. The normalized data is what the discriminator receives.
However, when calculating the gradient penalty in lines 270-281 below, `sample_amp_expert` is used directly, without normalization. According to the original AMP paper, the input for the gradient penalty should be the expert data as the discriminator sees it. As written, the gradient penalty minimizes the discriminator's gradients on a different part of the data manifold instead of on the (normalized) expert data as expected.
```python
sample_amp_expert = torch.cat(sample_amp_expert, dim=-1)
sample_amp_expert.requires_grad = True
disc = self.discriminator.amp_linear(self.discriminator.trunk(sample_amp_expert))
ones = torch.ones(disc.size(), device=disc.device)
disc_demo_grad = torch.autograd.grad(disc, sample_amp_expert,
                                     grad_outputs=ones,
                                     create_graph=True, retain_graph=True, only_inputs=True)
disc_demo_grad = disc_demo_grad[0]
disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
grad_pen_loss = self.disc_grad_penalty * torch.mean(disc_demo_grad)
# amp_loss += 0.2 * disc_grad_penalty  # self._disc_grad_penalty: 0.2
```

Finally, when updating the normalizer in lines 308-310 shown below, the normalized data is used. In my understanding of the code, the data before normalization should be used here to estimate the running mean and std. This part may lead to a wrong mean and std, though it probably does not have a large impact on the results.
```python
if self.amp_normalizer is not None:
    self.amp_normalizer.update(policy_state.cpu().numpy())
    self.amp_normalizer.update(expert_state.cpu().numpy())
```
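A sketch of the fix this issue suggests: normalize the expert transition pair first, then compute the gradient penalty at those normalized points. The discriminator class and the `mean`/`std` arguments below are stand-ins for illustration, not the repository's exact code:

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Stand-in with the same trunk / amp_linear split as in the snippet."""
    def __init__(self, obs_dim):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(2 * obs_dim, 64), nn.ReLU())
        self.amp_linear = nn.Linear(64, 1)

def grad_penalty_on_normalized_expert(discriminator, expert_state, expert_next_state,
                                      mean, std, penalty_coef=0.2):
    # Normalize before concatenating, so the penalty is evaluated at the
    # same points the discriminator sees during its classification loss.
    expert = torch.cat([(expert_state - mean) / std,
                        (expert_next_state - mean) / std], dim=-1)
    expert.requires_grad_(True)
    disc = discriminator.amp_linear(discriminator.trunk(expert))
    ones = torch.ones_like(disc)
    grad = torch.autograd.grad(disc, expert, grad_outputs=ones,
                               create_graph=True, retain_graph=True,
                               only_inputs=True)[0]
    # Mean squared gradient norm over the expert batch, scaled by the coefficient.
    return penalty_coef * torch.sum(torch.square(grad), dim=-1).mean()
```

The normalizer's `update` would correspondingly be fed the raw tensors (e.g. `self.amp_normalizer.update(raw_expert_state.cpu().numpy())`) rather than the normalized ones.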