本文由部分公式,因简书不支持公式渲染,公式完整版请移步个人博客
CapsNet基本结构
参考CapsNet的论文,提出的基本结构如下所示:

capsnet_mnist.jpg
可以看出,CapsNet的基本结构如下所示:
- 普通卷积层Conv1:基本的卷积层,感受野较大,达到了9x9
- 预胶囊层PrimaryCaps:为胶囊层准备,运算为卷积运算,最终输出为[batch,caps_num,caps_length]的三维数据:
- batch为批大小
- caps_num为胶囊的数量
- caps_length为每个胶囊的长度(每个胶囊为一个向量,该向量包括caps_length个分量)
- 胶囊层DigitCaps:胶囊层,目的是代替最后一层全连接层,输出为10个胶囊
代码实现
胶囊相关组件
激活函数Squash
胶囊网络有特有的激活函数Squash函数:
$$
Squash(S) = \cfrac{||S||2}{1+||S||2} \cdot \cfrac{S}{||S||}
$$
其中输入为S胶囊,该激活函数可以将胶囊的长度压缩,代码实现如下:
def squash(inputs, axis=-1):
norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
return scale * inputs
其中:
-
norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
计算输入胶囊的长度,p=2
表示计算的是二范数,keepdim=True
表示保持原有的空间形状。 -
scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
计算缩放因子,即$ \cfrac{||S||2}{1+||S||2} \cdot \cfrac{1}{||S||}$ -
return scale * inputs
完成计算
预胶囊层PrimaryCaps
class PrimaryCapsule(nn.Module):
"""
Apply Conv2D with `out_channels` and then reshape to get capsules
:param in_channels: i