One-hot的使用

文章介绍了在PyTorch中如何使用`nn.functional.one_hot`函数解决损失函数输入数据尺寸不一致的问题,特别是当标签数据缺少通道数时。它强调了one_hot函数要求输入为整型,并展示了如何将浮点型数据转换为整型,以及在one_hot操作后将结果转换回浮点型以适应模型的bias。最后,提到将one_hot处理后的数据用于计算损失函数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pytorch官方已经提供了具体one-hot函数,可以直接使用

from torch.nn.functional import one_hot

result = one_hot(data,num_class)

首先介绍下,one_hot可以运用在当调用损失函数时,所放入数据尺寸不等的时候

例如:

    raise ValueError("Target size ({}) must be the same as input size    ({})".format(target.size(), input.size()))


ValueError: Target size (torch.Size([2, 256, 256])) must be the same as input size (torch.Size([2, 2, 256, 256]))

可以看到上面输入的预测是一个[2, 2, 256, 256] 的tensor(分别是nchw,即batch_size, channels, height, width)

而label的则是[2, 256, 256] ,缺少了通道数(因为label是单通道)

因此可以使用one_hot进行转换,将tensor数据直接放入one_hot中的第一个参数的位置,因为只有2类,所以num_class=2

from torch.nn.functional import one_hot

result = one_hot(label,2)

但直接放有时会遇到数据类型的问题,由于one_hot只接受int类型,所以若label数据是float类型,需要先将类型转换为int

label = label.to(dtype=torch.int64)

然后再放进one_hot

最后,由于模型的bias比较常为float类型,所以直接把输出的one_hot放进去会报数据类型错误

,因此在one_hot输出后,还需要再转换为float类型

one_hot_output = one_hot_output.type(torch.float32)

最后,把one_hot输出和pred放进计算loss的函数进行计算

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值