MMD(p,q) > or = 0 ,当p=q的时候才会为0
当时使用pytorch计算的时候可能会出现 为 nan 的情况
这是因为 p和q矩阵中每个元素都是相同的
# 基于pytorch
import torch
from torch.autograd import Variable
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
n_samples = int(source.size()[0])+int(target.size()[0])
total = torch.cat([source, target], dim=0)
total0 = total.unsqueeze(0).expand

文章讨论了在使用PyTorch计算最大均值差异(MMD)时可能会遇到的NaN问题,特别是在p和q矩阵元素相同时。解决方案包括检查数据的线性无关性和调整学习率。代码示例展示了如何计算高斯核以及MMD损失,并提出了在元素全相等时可能出现的问题。
最低0.47元/天 解锁文章
2702





