下面代码仅适用于二分类算法
import torch
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_zeros(target_shape)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(valid_mask, as_tuple=True)
if inds[0].numel() > 0:
if labels.dim() == 3:
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
else:
bin_labels[inds[0], labels[valid_mask]] = 1
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights = bin_label_weights * valid_mask
return bin_labels, bin_label_weights, valid_mask
def transfer_label_2_pred(pred, label, ignore_index=255, weight=None):
    """Convert ``label`` so that it matches the layout of ``pred``.

    When ``pred`` has one more dimension than ``label`` (``[N, C]`` vs ``[N]``,
    or ``[N, C, H, W]`` vs ``[N, H, W]``), the integer labels are expanded
    into a one-hot tensor of ``pred.shape`` by ``_expand_onehot_labels``.

    Args:
        pred: Prediction tensor. If its second axis has size 1, it is
            treated as a single-channel binary prediction, and that axis is
            squeezed before the dimension comparison.
        label: Integer label tensor.
        ignore_index: Label value to ignore. Defaults to 255.
        weight: Optional per-element weights passed through to
            ``_expand_onehot_labels``.

    Returns:
        ``(label, weight, valid_mask)`` when the labels were expanded to
        one-hot form. Otherwise, the original ``label`` object unchanged.
        NOTE: the two paths return different types, so callers must handle
        both.

    Raises:
        AssertionError: if a single-channel ``pred`` is paired with labels
            other than 0/1 (ignoring ``ignore_index``), or if the
            ``pred``/``label`` rank combination is unsupported.
    """
    if pred.size(1) == 1:
        # Binary segmentation with a single-channel prediction: pred is
        # e.g. [N, 1, H, W] and label is [N, H, W]. ignore_index (often 255)
        # is masked out before checking that only labels 0/1 occur.
        valid_labels = label[label != ignore_index]
        # `.max()` raises on an empty tensor, so skip the range check when
        # every element is ignored.
        assert valid_labels.numel() == 0 or valid_labels.max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        pred = pred.squeeze(1)

    if pred.dim() == label.dim():
        # Shapes already agree; nothing to expand.
        return label

    assert (pred.dim() == 2 and label.dim() == 1) or (
            pred.dim() == 4 and label.dim() == 3), \
        'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
        'H, W], label shape [N, H, W] are supported'
    # `weight` returned from `_expand_onehot_labels`
    # has already been zeroed at ignored (invalid) positions.
    label, weight, valid_mask = _expand_onehot_labels(
        label, weight, pred.shape, ignore_index)
    return label, weight, valid_mask
本文介绍如何将二分类分割任务的标签从 (8, 256, 256) 扩展为与预测结果形状一致的 one-hot 形式 (8, 2, 256, 256)，这是深度学习中常见的数据预处理步骤。
126072160,14658339,Java 19新特性:虚拟线程与协程的崛起,"['Java', '架构', '并发编程', '预览版特性', '操作系统移植']
3012

被折叠的 条评论
为什么被折叠?



