AMN关键代码详解

Threshold Matters in WSSS: Manipulating the Activation for the Robust and Accurate Segmentation Model Against Thresholds

train_amn.py

logit = model(img, label_cls)

B, C, H, W = logit.shape

label_amn = resize_labels(label_amn.cpu(), size=logit.shape[-2:]).cuda()
# 将类别标签 label_amn 调整为与 logit 的预测输出大小相同,保证类别标签和预测输出匹配。
label_ = label_amn.clone()
label_[label_amn == 255] = 0
# 处理无效类别标签或者边界标签
given_labels = torch.full(size=(B, C, H, W), fill_value=args.eps/(C-1)).cuda()
# 创建一个与 logit 相同大小的张量,其中每个元素填充为 args.eps/(C-1)。这个张量将在下一步中用于生成目标标签
given_labels.scatter_(dim=1, index=torch.unsqueeze(label_, dim=1), value=1-args.eps)
# 在 dim=1 维度上使用 label_ 的值,在 given_labels 张量中将相应的位置设置为 1-args.eps,以生成目标标签。
# 这实际上是为了在 given_labels 中设置与真实类别对应的位置为 1,其他位置为 1-args.eps。
loss_pcl = balanced_cross_entropy(logit, label_amn, given_labels)
# 计算平衡的交叉熵损失
loss = loss_pcl
loss.backward()

涉及的调用函数

def balanced_cross_entropy(logits, labels, one_hot_labels):
    """
    :param logits: shape: (N, C)
    :param labels: shape: (N, C)
    :param reduction: options: "none", "mean", "sum"
    :return: loss or losses
    """

    N, C, H, W = logits.shape

    assert one_hot_labels.size(0) == N and one_hot_labels.size(1) == C, f'label tensor shape is {one_hot_labels.shape}, while logits tensor shape is {logits.shape}'

    log_logits = F.log_softmax(logits, dim=1)
    loss_structure = -torch.sum(log_logits * one_hot_labels, dim=1)  # (N)
	# 相应位置的 one_hot_labels 与 log_softmax 进行点积得到每个样本的损失。
    ignore_mask_bg = torch.zeros_like(labels)
    ignore_mask_fg = torch.zeros_like(labels)
    
    ignore_mask_bg[labels == 0] = 1  # 忽略背景掩码
    ignore_mask_fg[(labels != 0) & (labels != 255)] = 1 # 忽略前景类别
    
    loss_bg = (loss_structure * ignore_mask_bg).sum() / ignore_mask_bg.sum()
    loss_fg = (loss_structure * ignore_mask_fg).sum() / ignore_mask_fg.sum()

    return (loss_bg+loss_fg)/2


def resize_labels(labels, size):
    """
    Downsample labels for 0.5x and 0.75x logits by nearest interpolation.
    Other nearest methods result in misaligned labels.
    -> F.interpolate(labels, shape, mode='nearest')
    -> cv2.resize(labels, shape, interpolation=cv2.INTER_NEAREST)
    """
    new_labels = []
    for label in labels:
        label = label.float().numpy()
        label = Image.fromarray(label).resize(size, resample=Image.NEAREST)
        new_labels.append(np.asarray(label))
    new_labels = torch.LongTensor(new_labels)
    return new_labels
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Env1sage

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值