关于nn.upsample在GPU上无法兼容BFloat16的问题

在CNN中,nn.upsample常用于上采样操作,尤其是最近大火的扩散模型中,UNet网络的上采样均是采用这个操作执行该任务。尽管如此,nn.upsample在GPU上运行时,与torch.bfloat16会发生冲突,常给出这样的错误:RuntimeError:“upsample_nearest2d_out_frame” not implemented for ‘BFloat16’,从而导致很多高性能计算受阻。torch.bfloat16数据格式,是指"Brain Floating Point"格式占位16位,由Google Brain发明,专门为TPU研制,这种格式有很多优越的性能(详见https://cloud.google.com/tpu/docs/bfloat16?hl=zh-cn);后面人们发现这种数据格式,在GPU框架下的训练速度很快,同时对性能影响很小。如Lightning库(https://lightning.ai/)专门为Pytorch加速时,常使用这种数据格式,我们尝试过,使用这种数据格式训练扩散模型,每迭代1000次,要比其他数据格式快10s左右(在3090上)。因此,这个数据格式nn.upsample这个类在GPU上计算不兼容,将极大地影响学习进程。注意:只是在GPU上,会冲突,在CPU上不会冲突。即:

import torch.nn as nn
data=torch.rand(4,3,8,8,dtype=torch.bfloat16)
up = nn.Upsample(scale_factor=2.0, mode="nearest")
output_up = up(data)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值