今天计算损失函数,需要将cosine的值小于0的置0:
# torch.cosine_similarity(le, re, dim=1) 是我的Tensor
torch.where(torch.cosine_similarity(le, re, dim=1) < 0, 0, torch.cosine_similarity(le, re, dim=1))
通用模板如下,第一个参数是控制条件,这里是小于0,第二个元素是满足条件的赋值,第三个是不满足条件的赋值,这里不小于0则保留原来的数值。
# 张量a
torch.where(a < 0, 0, a)
同理,不小于0的可以赋值为1。
# 张量a
torch.where(a < 0, 0, 1)