logsoftmax(dim=1)是对批量样本中的每个样本取概率,而不是针对每个特征

博客指出logsoftmax(dim=1)是针对样本样本中的每个样本取概率,而非针对每个特征。这明确了logsoftmax在特定维度下的作用对象,对理解其功能有重要意义。

class BatchContrastiveLoss(nn.Module): def __init__(self, temperature=0.1, neg_ratio=5): super().__init__() self.temperature = temperature self.neg_ratio = neg_ratio # 正负样本比例 1:5 def forward(self, embeddings, pos_pairs, neg_pairs): """ 改进的批量对比损失实现 参数: pos_pairs: [(anchor_idx, pos_idx)] 长度N neg_pairs: [(anchor_idx, [neg_idx1, neg_idx2,...])] 长度N,每个元素包含K个负样本 """ # 构造批量数据 anchors = [a for a, _ in pos_pairs] positives = [p for _, p in pos_pairs] negatives = [negs for _, negs in neg_pairs] # 转换为张量 anchors = torch.tensor(anchors, device=embeddings.device) # [N] positives = torch.tensor(positives, device=embeddings.device) # [N] negatives = torch.stack([torch.tensor(negs, device=embeddings.device) for negs in negatives]) # [N, K] # 获嵌入向量 anchor_emb = embeddings[anchors] # [N, D] positive_emb = embeddings[positives] # [N, D] negative_emb = embeddings[negatives] # [N, K, D] # 计算正样本相似度 pos_sim = F.cosine_similarity(anchor_emb, positive_emb, dim=-1) # [N] # 计算负样本相似度 anchor_expanded = anchor_emb.unsqueeze(1) # [N, 1, D] neg_sim = F.cosine_similarity(anchor_expanded, negative_emb, dim=-1) # [N, K] # 组合logits logits = torch.cat([ pos_sim.unsqueeze(1), # [N, 1] neg_sim # [N, K] ], dim=1) / self.temperature # [N, 1+K] # 目标标签(正样本位置为0) labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device) return F.cross_entropy(logits, labels) 记住这段损失函数
03-25
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值