spikingjelly学习-包装器

对应内容【https://spikingjelly.readthedocs.io/zh-cn/0.0.0.0.14/activation_based/container.html】

import torch
from spikingjelly.activation_based import neuron, functional, layer

net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]

net_s.reset()
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
# z_seq.shape = [T, N, C, H, W]

# z_seq is identical to y_seq

[这段代码演示了如何在SpikingJelly框架中使用不同的包装器来处理脉冲神经网络(SNN)的输入序列。SpikingJelly是一个专注于脉冲神经网络的深度学习框架,提供了多种工具和模型来模拟和训练SNN。下面是对代码中使用的关键概念和组件的详细解释:
函数风格的multi_step_forward和模块风格的MultiStepContainer
multi_step_forward函数:这是一个函数风格的工具,它允许你将一个单步模块(例如,一个单个时间步长的SNN层或节点)应用于一个输入序列的多个时间步长。这意味着你可以将整个时间序列一次性传递给这个函数,而函数内部会处理序列中每个时间步长的前向传播。
MultiStepContainer模块:这是一个模块风格的包装器,它可以将任何单步模块包装成一个多步模块。这样,当你将一个时间序列作为输入传递给这个包装后的模块时,它会自动处理整个序列,而不需要显式地在代码中循环每个时间步长。
函数风格的seq_to_ann_forward和模块风格的SeqToANNContainer
这些工具和包装器用于将脉冲神经网络(SNN)的输出转换为传统的人工神经网络(ANN)的输入。这在混合网络架构中非常有用,其中SNN部分处理时间序列数据,而ANN部分负责进一步的数据处理或决策制定。
对单步模块进行包装以进行单步/多步传播的StepModeContainer
StepModeContainer:这是一个更通用的包装器,它允许你控制包装的模块是以单步模式还是多步模式运行。这为在不同场景下灵活使用模块提供了便利。
代码解释
在提供的代码片段中,首先使用neuron.IFNode创建了一个积分并发放(Integrate-and-Fire, IF)神经元节点,这是SNN中最基本的神经元模型之一。然后,演示了如何使用multi_step_forward函数和MultiStepContainer模块来处理一个时间序列的输入。]

multi_step_forward函数 处理的是序列经过网络 多步传输
MultiStepContainer模块 处理的是网络,然后用网络来通过序列

在这里插入图片描述
在这里插入图片描述

【对于无状态的ANN网络层,例如 torch.nn.Conv2d,其本身要求输入数据的 shape = [N, *],若用于多步模式,则可以用多步的包装器进行包装:】

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer

with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    # y_seq.shape = [T, N, 8, H, W]

    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    # z_seq.shape = [T, N, 8, H, W]

    # z_seq is identical to y_seq

在这里插入图片描述
【但是ANN的网络层本身是无状态的,不存在前序依赖,没有必要在时间上串行的计算,可以使用函数风格的 seq_to_ann_forward 或模块风格的 SeqToANNContainer 进行包装。seq_to_ann_forward 将 shape = [T, N, *] 的数据首先变换为 shape = [TN, *],再送入无状态的网络层进行计算,输出的结果会被重新变换为 shape = [T, N, *]。不同时刻的数据是并行计算的,因而速度更快:】

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer

with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    # y_seq.shape = [T, N, 8, H, W]

    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    # z_seq.shape = [T, N, 8, H, W]

    # z_seq is identical to y_seq

    p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
    # p_seq.shape = [T, N, 8, H, W]

    net = layer.SeqToANNContainer(conv, bn)
    q_seq = net(x_seq)
    # q_seq.shape = [T, N, 8, H, W]

    # q_seq is identical to p_seq, and also identical to y_seq and z_seq

【这两段代码展示了如何将SNN的时间序列输入转换为适用于传统人工神经网络(ANN)的形式。functional.seq_to_ann_forward是一个函数风格的工具,而SeqToANNContainer提供了一个模块风格的包装器。两者都实现了相同的功能,即处理时间序列输入并将其转换为ANN可以处理的形式。
总结
这段代码通过不同的方法展示了如何在SpikingJelly中处理SNN的时间序列输入。它使用了函数风格的multi_step_forward和seq_to_ann_forward,以及模块风格的MultiStepContainer和SeqToANNContainer来实现这一目的。这些工具和包装器使得在PyTorch框架下构建和模拟SNN变得更加灵活和方便。】
在这里插入图片描述
【常用的网络层,在 spikingjelly.activation_based.layer 已经定义过,更推荐使用 spikingjelly.activation_based.layer 中的网络层,而不是使用 SeqToANNContainer 手动包装,尽管 spikingjelly.activation_based.layer 中的网络层实际上就是用包装器包装 forward 函数实现的。spikingjelly.activation_based.layer 中的网络层,优势在于:

支持单步和多步模式,而 SeqToANNContainer 和 MultiStepContainer 包装的层,只支持多步模式

包装器会使得 state_dict 的 keys() 也增加一层包装,给加载权重带来麻烦】

最好不要用函数风格的 seq_to_ann_forward 和模块风格的 SeqToANNContainer

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer, neuron


ann = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU()
)

print(f'ann.state_dict.keys()={ann.state_dict().keys()}')

net_container = nn.Sequential(
    layer.SeqToANNContainer(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(8),
    ),
    neuron.IFNode(step_mode='m')
)
print(f'net_container.state_dict.keys()={net_container.state_dict().keys()}')

net_origin = nn.Sequential(
    layer.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    neuron.IFNode(step_mode='m')
)
print(f'net_origin.state_dict.keys()={net_origin.state_dict().keys()}')

try:
    print('net_container is trying to load state dict from ann...')
    net_container.load_state_dict(ann.state_dict())
    print('Load success!')
except BaseException as e:
    print('net_container can not load! The error message is\n', e)

try:
    print('net_origin is trying to load state dict from ann...')
    net_origin.load_state_dict(ann.state_dict())
    print('Load success!')
except BaseException as e:
    print('net_origin can not load! The error message is', e)
ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])
net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container is trying to load state dict from ann...
net_container can not load! The error message is
Error(s) in loading state_dict for Sequential:
    Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var".
    Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked".
net_origin is trying to load state dict from ann...
Load success!

用SeqToANNContainer 构建的网络,不能把值加载进去,因为0.0 和0 不一样,有区别

在这里插入图片描述
【MultiStepContainer 和 SeqToANNContainer 都是只支持多步模式的,不允许切换为单步模式。

StepModeContainer 类似于融合版的 MultiStepContainer 和 SeqToANNContainer,可以用于包装无状态或有状态的单步模块,需要在包装时指明是否有状态,但此包装器还支持切换单步和多步模式。】

无状态的ann,net = layer.StepModeContainer(

import torch
from spikingjelly.activation_based import neuron, layer


with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    net = layer.StepModeContainer(
        False,
        nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(C),
    )
    net.step_mode = 'm'
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C, H, W]

    net.step_mode = 's'
    y = net(x_seq[0])
    # y.shape = [N, C, H, W]

有snn的 net = layer.StepModeContainer(
True,
neuron.IFNode()

import torch
from spikingjelly.activation_based import neuron, layer, functional


with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    net = layer.StepModeContainer(
        True,
        neuron.IFNode()
    )
    net.step_mode = 'm'
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C, H, W]
    functional.reset_net(net)

    net.step_mode = 's'
    y = net(x_seq[0])
    # y.shape = [N, C, H, W]
net.step_mode=m
net[0].step_mode=s

[如果模块本身就支持单步和多步模式的切换,则不推荐使用 MultiStepContainer 或 StepModeContainer 对其进行包装。因为包装器使用的多步前向传播,可能不如模块自身定义的前向传播速度快。]

[通常需要用到 MultiStepContainer 或 StepModeContainer 的是一些没有定义多步的模块,例如一个在 torch.nn 中存在,但在 spikingjelly.activation_based.layer 中不存在的网络层。]
特殊的层,使用

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

weixin_44781508

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值