Pytorch 类别标签转换one-hot编码

本文详细介绍了PyTorch中scatter_函数的使用方法,包括如何将索引标签转换为one-hot编码,并展示了在分类任务和图像分割中的应用实例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这里用到了Pytorch的scatter_函数:

scatter_(dim, index, src) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
For a 3-D tensor, self is updated as:

self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

注意要保证self和index维度一致

对于分类问题,标签可以是类别索引值也可以是one-hot表示。以10类别分类为例,lable=[3] 和label=[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]是一致的.

>>>class_num = 10
>>>batch_size = 4
>>>label = torch.LongTensor(batch_size, 1).random_() % class_num
 3
 0
 0
 8

>>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
    0     0     0     1     0     0     0     0     0     0
    1     0     0     0     0     0     0     0     0     0
    1     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     1     0

自己做图像分割时,把图像标签转成one-hot编码形式,原图像(1,512,512),生成one-hot(class_nums, 512, 512):

gt_onehot = torch.zeros((class_nums, gt.shape[1], gt.shape[2]))
gt_onehot.scatter_(0, gt.long(), 1)

参考:
1.https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_
2.PyTorch——Tensor_把索引标签转换成one-hot标签表示

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值