记录一下个人学习和使用Pytorch中的一些问题。强烈推荐 《深度学习框架PyTorch:入门与实战》.写的非常好而且作者也十分用心,大家都可以看一看,本文为学习第八章风格迁移的学习笔记。
主要分析实现代码里面main,transformer,PackedVGG,utils这4个代码文件完成整个项目模型结构定义,训练及生成,还有输出展示的整个过程。
utils
这是项目中用到的一些额外工具包,对visdom封装更方便显示多个点以及网格显示图片这个老朋友就不多说了。此外还有一个图像里面常用的BN操作(注意,这里的BN时IBN既每个样本一个标准化,而不是整批的标准化,而且BN的均值标准差不是一般的0.5时专人计算过的imagenet上的均值和标准差)的函数以及获取封装的从路径获取风格图片的函数,还有一个在这种风格迁移任务里用到的计算Gram矩阵的函数。
Gram矩阵
Gram是用来描述图像整体风格特征的一种矩阵。假设输入的图像尺寸是BxCxHxW,那么得到的Gram矩阵的尺寸是BxCxC既只于输入的通道数有关。矩阵取值由如下公式计算
其中 F i k F_{ik} Fik代表第i个输入通道(实际上输入的是多个特征图,所以就是第i个特征图)的第k个像素,同时可以发现该计算和位置信息无关,也就是说原始图片随机打散后生成新的特征图重新计算Gram矩阵也是一致的。说明Gram提取是整体风格而不是结构信息。代码实现计算Gram矩阵如下。
def gram_matrix(y):
"""
Input shape: b,c,h,w
Output shape: b,c,c
"""
(b, ch, h, w) = y.size()
features = y.view(b, ch, w * h)
features_