老板,来一斤BUG!——Pytorch

本文详细解析PyTorch中数据类型的默认设置及其重要性,包括权重的float32格式,以及进行wx+b运算时数据类型的一致性需求。同时,介绍了CrossEntropyLoss在N分类任务中的应用,强调了网络输出与标签格式的要求。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

2019/2/28:

torch.Tensor默认的tensor类型(torch.FloatTensor)的简称。也就是float、float32。

pytorch的权重weight,若未手动修改,默认格式一定是float32,不会随输入的格式而自适应更改。

问题在于,在pytorch中,要进行wx+b,w与x的数据格式要一致!!!

最后,如果要进行反向传播更新参数,那么权重weight这类参数的数据格式一定不能是整型int!!!

2019/3/1:

torch.nn.CrossEntropyLoss()用于N分类,要求:

网络最后的全连接层输出N个特征图,且不需要加激活层。

标签 target 应为 LongTensor类型(BCELoss要求为DoubleTensor)。

如果使用 torch.utils.data.TensorDataset 与 torch.utils.data.DataLoader,那么 0<= label[index]<=N-1。标签为1D,不需要one-hot。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值