According to the paper, the negative component of the contrastive loss is the distance between a negative state z̃ (sampled at random from the time-step-t embeddings across the batch) and the ground-truth next state z_{t+1}, as written out below.
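For reference, this is how I read the hinge loss in the paper (my own transcription, so the notation may differ slightly from the original):

```latex
% Positive term: distance between the predicted next state and the true next state z_{t+1}.
% Negative term: hinge on the distance between a negative sample \tilde{z} and the TRUE next state z_{t+1}.
\mathcal{L} = d\big(z_t + T(z_t, a_t),\, z_{t+1}\big)
            + \max\big(0,\; \gamma - d(\tilde{z},\, z_{t+1})\big),
\qquad d(x, y) = \tfrac{1}{2\sigma^2}\,\lVert x - y \rVert_2^2
```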
However, in line 113 of modules.py, with no_trans=True, the energy is effectively the difference between the negative sample z̃ and z_t (rather than z_{t+1}):
```python
def contrastive_loss(self, obs, action, next_obs):
    objs = self.obj_extractor(obs)
    next_objs = self.obj_extractor(next_obs)
    state = self.obj_encoder(objs)
    next_state = self.obj_encoder(next_objs)

    # Sample negative state across episodes at random
    batch_size = state.size(0)
    perm = np.random.permutation(batch_size)
    neg_state = state[perm]

    self.pos_loss = self.energy(state, action, next_state)
    zeros = torch.zeros_like(self.pos_loss)

    self.pos_loss = self.pos_loss.mean()
    self.neg_loss = torch.max(
        zeros, self.hinge - self.energy(
            state, action, neg_state, no_trans=True)).mean()

    loss = self.pos_loss + self.neg_loss
    return loss
```
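For context, the energy function in modules.py looks roughly like this (paraphrased from memory, so treat it as a sketch rather than an exact quote). The point is that with no_trans=True it skips the transition model and simply measures the distance between its first and third arguments:

```python
def energy(self, state, action, next_state, no_trans=False):
    """Energy based on the normalized squared L2 distance (sketch, not an exact quote)."""
    norm = 0.5 / (self.sigma ** 2)

    if no_trans:
        # No transition model applied: distance between the first and third arguments only.
        diff = state - next_state
    else:
        # Positive term: distance between the predicted next state and the true next state.
        pred_trans = self.transition_model(state, action)
        diff = state + pred_trans - next_state

    return norm * diff.pow(2).sum(2).mean(1)
```

So in the negative term above, the quantity being penalized is state - neg_state, i.e. d(z_t, z̃), not d(z_{t+1}, z̃) as in the paper.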
Thus, I feel that next_state, rather than state, should be the first argument of the energy function in the negative term, as sketched below. Please let me know if I am misconstruing anything.
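Concretely, this is the change I have in mind (a minimal sketch of the negative term only, not a tested patch):

```python
# Negative term: compare the negative sample against the ground-truth next state,
# matching d(z~, z_{t+1}) from the paper. Only the first argument changes.
self.neg_loss = torch.max(
    zeros, self.hinge - self.energy(
        next_state, action, neg_state, no_trans=True)).mean()
```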
Thanks.