目录
1.pytorch_ssim文件
1.1 gaussian函数
def gaussian(window_size, sigma):
'''
形成大小为window_size,标准差为sigma的一串高斯序列
:param window_size:
:param sigma:
:return:
'''
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
1.2 create_window函数
def create_window(window_size, channel):
'''
:param window_size: 矩阵大小
:param channel: 通道大小
:return: (channel,1,window_size,window_size)大小的高斯滤波器
'''
# .unsqueeze(1)是插入第二个维度
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
# _1D_window.mm矩阵乘法
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
# .contiguous()的作用是返回一个具有相同数据但存储方式不同的新张量。
# 这通常用于确保张量在内存中是连续存储的,以便后续的计算操作。
# 将处理过的二维高斯滤波器包装在Variable对象中,并将其作为函数的返回值。
# 在较新版本的PyTorch中,Variable对象已经被整合到了torch.Tensor中,所以在实际应用中可能不再需要显式地使用Variable。
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
这两部分可以看做生成一个高斯滤波器的代码, 至于什么是高斯滤波器, 它有什么作用, 参考这篇博客. 简而言之就是利用卷积运算, 对图像进行平滑处理, 减少高频部分, 保留低频部分. 由于高频部分(细节)部分消失, 图像就会变得模糊. 高斯滤波器对于抑制 高斯噪声
(服从正态分布的噪声) 非常有效. 这里应该是考虑到图像生成过程中, 或多或少会有噪声, 所以加入高斯滤波器减少噪声. 也可能是减少细节部分, 对两个图片进行综合的评价(因为细节会拉大评价差距?)
1.3 SSIM评价指标函数
def _ssim(img1, img2, window, window_size, channel, size_average = True):
'''
这段代码实现了SSIM(结构相似度)指标的计算,用于度量两幅图像之间的结构相似性。
:param img1:图像1,大小(batch_size, channel, height, width)
:param img2:图像2,大小(batch_size, channel, height, width)
:param window::高斯滤波器,大小(channel, 1, window_size, window_size)
:param window_size:高斯滤波器大小 window_size
:param channel:通道数
:param size_average:一个布尔值,表示是否对SSIM值进行平均。
如果为True,则返回所有像素的均值作为SSIM值;如果为False,则返回每张图像的SSIM值的均值
:return:
'''
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
什么是SSIM指标参考这篇---博客
这里再补充一点, 对于 size_average 参数, 举个例子, 若比较的图像为img1 = (10,3,5,5), img2 =(10,3,5,5)的大小,选定size_average = true, 则返回这10个图片的SSIM值的平均值; 若选定size_average = false, 则返回这10个图片对应比较的SSIM值, 即为一个大小为10的序列.
除此之外, mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
这里group = channel是将img1分为group个通道, 再与window做卷积(因为window = (channel, 1, window_size, window_size)). 这部分卷积的维度大小一定要弄明白, 可以多尝试几次.
1.4 SSIM评价指标类
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True):
'''
:param window_size: 高斯滤波器大小
:param size_average: 是否返回所有图片SSIM指标
'''
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2):
# 得到img1的通道
(_, channel, _, _) = img1.size()
# 若图片的通道和滤波器的通道一致,并且图片和滤波器的数据类型一致,则保留window
# 否则重新创建一个和图片通道一致的window
# 修改window的设备,数据类型
# 最后返回img1,img2的SSIM指标
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size=11, size_average=True):
# 和class SSIM相似,不过多解释
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
2. 总结
这部分文件主要是对img1, img2计算SSIM指标, 感觉比较简单, 主要是把SSIM的计算公式用python代码实现即可.