超分辨率中为什么不适用BN

本文探讨了BatchNorm在深度学习中的重要作用,尤其是在图像分类任务中的优势,同时揭示了在图像超分辨率和风格转移中为何效果不佳。重点分析了BatchNorm如何影响色彩对比度并指出其在特定网络结构(如ResidualNet)中的例外情况。

Batch Norm可谓深度学习中非常重要的技术,不仅可以使训练更深的网络变容易,加速收敛,还有一定正则化的效果,可以防止模型过拟合。在很多基于CNN的分类任务中,被大量使用。
但在图像超分辨率和图像生成方面,Batch Norm的表现并不好,加入了Batch Norm,反而使得训练速度缓慢,不稳定,甚至最后发散。

以图像超分辨率来说,网络输出的图像在色彩、对比度、亮度上要求和输入一致,改变的仅仅是分辨率和一些细节,而Batch Norm,对图像来说类似于一种对比度的拉伸,任何图像经过Batch Norm后,其色彩的分布都会被归一化,也就是说,它破坏了图像原本的对比度信息,所以Batch Norm的加入反而影响了网络输出的质量。虽然Batch Norm中的scale和shift参数可以抵消归一化的效果,但这样就增加了训练的难度和时间,还不如直接不用。不过有一类网络结构可以用,那就是残差网络(Residual Net),但也仅仅是在residual block当中使用,比如SRResNet,就是一个用于图像超分辨率的残差网络。为什么这类网络可以使用Batch Norm呢?有人认为是因为图像的对比度信息可以通过skip connection直接传递,所以也就不必担心Batch Norm的破坏了。

基于这种想法,也可以从另外一种角度解释Batch Norm为何在图像分类任务上如此有效。图像分类不需要保留图像的对比度信息,利用图像的结构信息就可以完成分类,所以,将图像都通过Batch Norm进行归一化,反而降低了训练难度,甚至一些不明显的结构,在Batch Norm后也会被凸显出来(对比度被拉开了)。

而对于照片风格转移,为何可以用Batch Norm呢?原因在于,风格化后的图像,其色彩、对比度、亮度均和原图像无关,而只与风格图像有关,原图像只有结构信息被表现到了最后生成的图像中。因此,在照片风格转移的网络中使用Batch Norm或者Instance Norm也就不奇怪了,而且,Instance Norm是比Batch Norm更直接的对单幅图像进行的归一化操作,连scale和shift都没有。

说得更广泛一些,Batch Norm会忽略图像像素(或者特征)之间的绝对差异(因为均值归零,方差归一),而只考虑相对差异,所以在不需要绝对差异的任务中(比如分类),有锦上添花的效果。而对于图像超分辨率这种需要利用绝对差异的任务,Batch Norm只会添乱。

原文链接:https://blog.youkuaiyun.com/prinstinadl/article/details/80835088

以下是一个示例程序,将ResNet更改为接受低辨率图像列表的输入,然后输出一张高辨率图像。 ```python import torch import torch.nn as nn import torch.nn.functional as F class ResidualBlock(nn.Module): def __init__(self, in_channels): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(in_channels) self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(in_channels) def forward(self, x): identity = x out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += identity out = F.relu(out) return out class UpsampleBlock(nn.Module): def __init__(self, in_channels, out_channels): super(UpsampleBlock, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) self.shuf = nn.PixelShuffle(2) self.bn = nn.BatchNorm2d(out_channels) def forward(self, x): out = self.conv(x) out = self.shuf(out) out = F.relu(self.bn(out)) return out class ResNetSR(nn.Module): def __init__(self, num_blocks=16, input_channels=3, output_channels=3): super(ResNetSR, self).__init__() self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=9, stride=1, padding=4, bias=False) self.bn1 = nn.BatchNorm2d(64) self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(num_blocks)]) self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(64) self.upsample = nn.Sequential( UpsampleBlock(64, 256), UpsampleBlock(64, 256), nn.Conv2d(64, output_channels, kernel_size=9, stride=1, padding=4, bias=False) ) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.res_blocks(out) out = F.relu(self.bn2(self.conv2(out))) out = self.upsample(out) return out ``` 在这个例子中,我们定义了三个子模块: - `ResidualBlock`是ResNet的基本块,用于提取图像的特征。 - `UpsampleBlock`是用于上采样特征的块。 - `ResNetSR`是我们的超分辨率模型,它由`ResidualBlock`和`UpsampleBlock`组成。 在`ResNetSR`中,我们首先使用一个大的卷积层(`conv1`)来提取特征,然后使用多个`ResidualBlock`来增加深度。之后,我们使用另一个卷积层(`conv2`)来进一步提取特征,并使用多个`UpsampleBlock`来上采样特征。最后,我们使用一个大的卷积层(`upsample`)来生成高辨率图像输出。 这个模型接受一个低辨率图像列表作为输入,并输出一张高辨率图像。在训练和测试期间,我们需要将低辨率图像列表堆叠在一起,以便将它们输入到模型中。例如: ```python lr_image_list = [lr_image1, lr_image2, lr_image3] lr_images = torch.stack(lr_image_list, dim=0) model = ResNetSR() sr_image = model(lr_images) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值