参考文档
JSD实现代码
若有纰漏,敬请指出,感谢!
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='batchmean')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output )/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
一些注意事项
-
关于dlv函数的使用:
函数中的 p q 位置相反(也就是想要计算D(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log
-
JS 散度度量了两个概率分布的相似度,基于KL散度的变体,解决了KL散度非对称的问题。所以jsd(q, p)与jsd(p, q)一致。