```python
t1 = (samples_var.pow(2) + samples_mean.pow(2)) / 2
t2 = -samples_var.log()
KL = (t1 + t2 - 0.5).mean()
```
The KLD appears to use the variance in place of the standard deviation. `utils.var()` computes the variance (the mean squared distance from the mean), and the `KLN01Loss` module then squares it again (`samples_var.pow(2)`). Should it instead be (in the default `'qp'` direction):
```python
t1 = samples_var + samples_mean.pow(2)  # sigma^2 + mu^2; samples_var is already a variance
t2 = -samples_var.log()                 # -log(sigma^2)
KL = (t1 + t2 - 1).mean() / 2
```
?
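For reference, this is the standard closed form of the KL divergence from a diagonal Gaussian with per-dimension mean $\mu_j$ and standard deviation $\sigma_j$ to $\mathcal{N}(0, I)$, summed over $M$ latent dimensions. It assumes, as described above, that `samples_var` holds the variance $v_j = \sigma_j^2$:

$$
\mathrm{KL}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
= \sum_{j=1}^{M} \left( \frac{\sigma_j^2 + \mu_j^2}{2} - \log \sigma_j - \frac{1}{2} \right)
= \frac{1}{2} \sum_{j=1}^{M} \left( v_j + \mu_j^2 - \log v_j - 1 \right)
$$

Given $v_j$, the quoted code computes $\tfrac{1}{2}(v_j^2 + \mu_j^2) - \log v_j - \tfrac{1}{2}$ per dimension. That matches the middle expression only if $v_j$ is replaced by $\sigma_j$. The proposed code is the right-hand expression, averaged over $j$ instead of summed.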
(Additionally, the paper writes the KLD as a sum over the latent dimensions, but here it is a mean. This divides the KL term by the number of latent dimensions and so changes the meaning of the hyperparameters that weight the reconstruction losses.)
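The following stdlib-only sketch checks both points numerically. It uses plain floats instead of PyTorch tensors, and the helper names `kl_current`, `kl_proposed` and `kl_numeric` are hypothetical. First, it compares the per-dimension formula from the quoted code and the proposed one against a direct numerical integration of KL(q || p), with q = N(mu, var) and p = N(0, 1). Second, it shows that `.mean()` over the per-dimension terms is the paper's sum divided by the number of dimensions.

```python
import math


def kl_current(mean, s):
    # Per-dimension term from the quoted code: (s^2 + mu^2)/2 - log(s) - 1/2.
    # Correct only if s is the standard deviation sigma.
    return (s ** 2 + mean ** 2) / 2 - math.log(s) - 0.5


def kl_proposed(mean, var):
    # Proposed per-dimension term, with var = sigma^2 (a variance).
    return (var + mean ** 2 - math.log(var) - 1) / 2


def kl_numeric(mean, var, n=20000, width=12.0):
    # Midpoint-rule estimate of KL(q || p) = integral of q(x) * (log q(x) - log p(x)),
    # with q = N(mean, var) and p = N(0, 1), over mean +/- width standard deviations.
    sd = math.sqrt(var)
    lo = mean - width * sd
    h = 2 * width * sd / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * h
        log_q = -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)
        log_p = -0.5 * math.log(2 * math.pi) - x ** 2 / 2
        total += math.exp(log_q) * (log_q - log_p) * h
    return total


# Arbitrary example statistics, one (mean, variance) pair per latent dimension.
means = [0.3, -1.0, 0.0]
variances = [0.5, 2.0, 1.0]

for m, v in zip(means, variances):
    print(f"mu={m:+.1f} var={v:.1f}  numeric={kl_numeric(m, v):.6f}  "
          f"proposed={kl_proposed(m, v):.6f}  "
          f"current(var)={kl_current(m, v):.6f}  "
          f"current(std)={kl_current(m, math.sqrt(v)):.6f}")

terms = [kl_proposed(m, v) for m, v in zip(means, variances)]
kl_sum = sum(terms)            # the paper's form: sum over latent dimensions
kl_mean = kl_sum / len(terms)  # what .mean() over the same terms gives
print(f"sum={kl_sum:.6f}  mean={kl_mean:.6f}  (sum = {len(terms)} * mean)")
```

The `current(std)` column agrees with `numeric` and `proposed`, while `current(var)` does not (except at the standard normal, where both formulas give 0). In `losses.py` the same expressions would apply elementwise to the `samples_mean` and `samples_var` tensors.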
The code quoted at the top is from `AGE/src/losses.py`, lines 38 to 41 in 0915760.