本章目的:用图卷积神经网络实现离散数据的分类 ( 以图像分类为例 ) .
5.1 卷积计算过程
在实际项目中,输入神经网络的是具有更高分辨率的彩色图片,使得送入全连接网络的输入特征数过多,随着隐藏层层数的增加,网络规模过大,待优化参数过多,很容易使模型过拟合. 为了减少待训练参数,在实际应用时,会先对原始图片进行特征提取,把提取来的特征送入全连接网络,让其输出识别结果.
卷积计算是一种有效的特征提取方法,一般会用一个正方形的卷积核,按指定步长在输入特征图上滑动,遍历这种输入特征图中的每个像素点,每滑动一个步长,卷积核会与输入特征图部分像素点重合,重合区域对应元素相乘、求和再加上偏置项,得到输出特征的一个像素点.
如果输入特征是单通道灰度图,使用深度为 1 的单通道卷积核;如果输入特征是三通道彩色图,使用 3*3*3 的卷积核或 5*5*3 的卷积核. 总之,要使卷积核的通道数与输入特征图的通道数一致. 这是因为,如果想要让卷积核与输入特征图对应点匹配上,必须让卷积核的深度与输入特征图的深度一致. 所以,输入特征图的深度决定了当前层卷积核的深度.
由于每个卷积核在卷积计算后会得到一个输出特征图,所以当前层使用了几个卷积核,就会有几张输出特征图. 所以,当前层卷积核的个数,决定了当前层输出特征图的深度.
卷积核立体图如上图所示,每个小颗粒都存储了一个待训练参数,在执行卷积计算时,卷积核中的参数是固定的,在每次反向传播时,这些待训练参数会被梯度下降法更新. 卷积就是利用立体卷积核实现了参数的空间共享.
- 具体计算过程:
对于输入特征图是单通道的,选择单通道卷积核,上图为 5 行 5 列单通道,选用 3*3 单通道卷积核,滑动步长为 1 . 在输入特征图上滑动,每滑动一步输入特征图与卷积核里的 9 个元素重合,对应元素相乘求和再加上偏置项 b ,如卷积核滑动到图上位置时,计算为:(-1)*1+0*0+1*2+(-1)*5+0*4+1*2+(-1)*3+0*4+1*5+1 = 1.
对于输入特征图是三通道的,选择三通道卷积核,计算过程如上图.
5.2 感受野 ( Receptive Field )
感受野指输出特征图 ( feature map ) 中的一个像素点映射到原始输入图片的区域大小.
如上图所示,原始输入图片大小为 5 * 5 ,用黄色的 3 * 3 卷积核作用,则会输出一个 3 * 3 的输出特征图,输出特征图上每个像素点映射到原始输入图片是 3 * 3 的区域,因此,它的感受野是 3 ;如果再对这个 3 * 3 的特征图用绿色的 3 * 3 卷积核作用,会输出一个 1 * 1 的输出特征图,那么,这个输出特征图上的像素点映射到原始图片是 5 * 5 的区域,因此,它的感受野是 5 .
如上图所示,对原始图片直接用蓝色的 5 * 5 卷积核作用,则会输出一个 1 * 1 的输出特征图,它的感受野为 5 .
对于同样一个 5 * 5 的原始图片,经过两层 3 * 3 卷积核作用和经过一层 5 * 5 卷积核作用都可以得到一个感受野是 5 的 1 * 1 输出特征图,因此,两层 3 * 3 卷积核和一层 5 * 5 卷积核的特征提取能力是一样的,但如何做选择,就需要考虑他们所承载的待训练参数和计算量. 假设输入特征宽和高都为 x ,卷积计算步长均为 1 ,具体计算如下:
- 对于两层 3 * 3 卷积核:
参数量:9 + 9 = 18
计算量:每个 3 * 3 卷积核计算得到一个输出像素点需要做 9 次乘加计算,两层则需要:
- 对于一层 5 * 5 卷积核:
参数量:25
计算量:每个 5 * 5 卷积核计算得到一个输出像素点需要做 25 次乘加计算,一层则需要:
因此,可以比较出,当输入特征图边长大于 10 个像素点时,两层 3 * 3 卷积核要比一层 5 * 5 卷积核性能好.
5.3 全零填充 ( Padding )
当希望卷积计算保持输入特征图的尺寸不变,可以使用全零填充,在输入特征图周围填充 0 .
卷积输出特征图维度的计算公式为:
其中,入长为输入特征图边长.
# 在 tf 中, 描述全零填充
padding = 'SAME' 或 padding = 'VALID'
5.4 TF 描述卷积计算层
- Tensorflow 给出了计算卷积的函数:
tf.keras.layers.Conv2D(
filters = 卷积核个数,
kenerl_size = 卷积核尺寸, # 正方形写核长整数, 或 ( 核高 h ,核宽 w )
strides = 滑动步长, # 横纵向相同写步长整数或 ( 纵向步长 h, 横向步长 w ), 默认 1
padding = 'same' or 'valid' # 默认是 valid
activation = 'relu' or 'sigmoid' or 'tanh' of 'softmax' ... # 如有 BN ( batch normalization ) 此处不写
input_shape =