基于Pytorch的CapsNet源码详解

本文详细解析基于Pytorch的CapsNet源码,包括CapsNet基本结构、胶囊相关组件(Squash激活函数、PrimaryCaps和DigitCaps层)以及网络结构和代价函数。介绍了动态路由算法和胶囊网络如何通过卷积和激活函数实现信息的压缩和传递。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文由部分公式,因简书不支持公式渲染,公式完整版请移步个人博客

CapsNet基本结构

参考CapsNet的论文,提出的基本结构如下所示:

7241055-7c32f43163b2a0c8.jpg
capsnet_mnist.jpg

可以看出,CapsNet的基本结构如下所示:

  • 普通卷积层Conv1:基本的卷积层,感受野较大,达到了9x9
  • 预胶囊层PrimaryCaps:为胶囊层准备,运算为卷积运算,最终输出为[batch,caps_num,caps_length]的三维数据:
    • batch为批大小
    • caps_num为胶囊的数量
    • caps_length为每个胶囊的长度(每个胶囊为一个向量,该向量包括caps_length个分量)
  • 胶囊层DigitCaps:胶囊层,目的是代替最后一层全连接层,输出为10个胶囊

代码实现

胶囊相关组件

激活函数Squash

胶囊网络有特有的激活函数Squash函数:
$$
Squash(S) = \cfrac{||S||2}{1+||S||2} \cdot \cfrac{S}{||S||}
$$
其中输入为S胶囊,该激活函数可以将胶囊的长度压缩,代码实现如下:

def squash(inputs, axis=-1):
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs

其中:

  • norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)计算输入胶囊的长度,p=2表示计算的是二范数,keepdim=True表示保持原有的空间形状。
  • scale = norm**2 / (1 + norm**2) / (norm + 1e-8)计算缩放因子,即$ \cfrac{||S||2}{1+||S||2} \cdot \cfrac{1}{||S||}$
  • return scale * inputs完成计算

预胶囊层PrimaryCaps

class PrimaryCapsule(nn.Module):
    """
    Apply Conv2D with `out_channels` and then reshape to get capsules
    :param in_channels: i
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值