
Discrepancy between the loss mentioned in the paper and GitHub #3

@bhattg

Description


According to the paper, the negative component of the contrastive loss is the distance between a negative state z̃_t (randomly sampled from the embeddings at timestep t) and the ground-truth next state z_{t+1}.

However, as per line 113 of modules.py, when no_trans is given, you are effectively computing the distance between the randomly sampled embedding z̃_t and z_t (rather than z_{t+1}).
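For reference, my reading of the loss in the paper (using the paper's squared-distance energy d and transition model T; notation mine) is:

```latex
% Positive term: predicted next state vs. ground-truth next state
H = d\big(z_t + T(z_t, a_t),\; z_{t+1}\big)
% Negative term: hinge on a randomly sampled negative vs. the ground-truth next state
\tilde{H} = \max\big(0,\; \gamma - d(\tilde{z}_t,\; z_{t+1})\big)
% Total contrastive loss
L = H + \tilde{H}
```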

```python
def contrastive_loss(self, obs, action, next_obs):
    objs = self.obj_extractor(obs)
    next_objs = self.obj_extractor(next_obs)

    state = self.obj_encoder(objs)
    next_state = self.obj_encoder(next_objs)

    # Sample negative state across episodes at random
    batch_size = state.size(0)
    perm = np.random.permutation(batch_size)
    neg_state = state[perm]

    self.pos_loss = self.energy(state, action, next_state)
    zeros = torch.zeros_like(self.pos_loss)

    self.pos_loss = self.pos_loss.mean()
    self.neg_loss = torch.max(
        zeros, self.hinge - self.energy(
            state, action, neg_state, no_trans=True)).mean()

    loss = self.pos_loss + self.neg_loss

    return loss
```
Thus, I believe next_state, rather than state, should be passed as the first argument to the energy function, so that the negatives are contrasted against z_{t+1} as in the paper. Please let me know if I am misconstruing anything.
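To make the difference concrete, here is a small self-contained sketch (not the repo's code; I'm substituting a plain squared-distance energy for self.energy with no_trans=True) showing that the two choices of first argument generally give different negative losses:

```python
import numpy as np

rng = np.random.default_rng(0)

def energy(a, b):
    # Stand-in for self.energy(..., no_trans=True): squared Euclidean distance.
    return np.sum((a - b) ** 2, axis=-1)

batch, dim, hinge = 8, 3, 25.0
z_t = rng.normal(size=(batch, dim))               # states at timestep t
z_t1 = z_t + 0.1 * rng.normal(size=(batch, dim))  # ground-truth next states
neg = z_t[rng.permutation(batch)]                 # randomly sampled negatives

# As currently written: negatives contrasted against z_t.
neg_loss_code = np.maximum(0.0, hinge - energy(z_t, neg)).mean()
# As I read the paper: negatives contrasted against z_{t+1}.
neg_loss_paper = np.maximum(0.0, hinge - energy(z_t1, neg)).mean()

print(neg_loss_code, neg_loss_paper)
```

The two terms coincide only when z_{t+1} happens to equal z_t, so the choice of first argument does change the objective.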

Thanks.
