对应内容【https://spikingjelly.readthedocs.io/zh-cn/0.0.0.0.14/activation_based/container.html】
import torch
from spikingjelly.activation_based import neuron, functional, layer
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]
net_s.reset()
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
# z_seq.shape = [T, N, C, H, W]
# z_seq is identical to y_seq
[这段代码演示了如何在SpikingJelly框架中使用不同的包装器来处理脉冲神经网络(SNN)的输入序列。SpikingJelly是一个专注于脉冲神经网络的深度学习框架,提供了多种工具和模型来模拟和训练SNN。下面是对代码中使用的关键概念和组件的详细解释:
函数风格的multi_step_forward和模块风格的MultiStepContainer
multi_step_forward函数:这是一个函数风格的工具,它允许你将一个单步模块(例如,一个单个时间步长的SNN层或节点)应用于一个输入序列的多个时间步长。这意味着你可以将整个时间序列一次性传递给这个函数,而函数内部会处理序列中每个时间步长的前向传播。
MultiStepContainer模块:这是一个模块风格的包装器,它可以将任何单步模块包装成一个多步模块。这样,当你将一个时间序列作为输入传递给这个包装后的模块时,它会自动处理整个序列,而不需要显式地在代码中循环每个时间步长。
函数风格的seq_to_ann_forward和模块风格的SeqToANNContainer
这些工具和包装器用于将脉冲神经网络(SNN)的输出转换为传统的人工神经网络(ANN)的输入。这在混合网络架构中非常有用,其中SNN部分处理时间序列数据,而ANN部分负责进一步的数据处理或决策制定。
对单步模块进行包装以进行单步/多步传播的StepModeContainer
StepModeContainer:这是一个更通用的包装器,它允许你控制包装的模块是以单步模式还是多步模式运行。这为在不同场景下灵活使用模块提供了便利。
代码解释
在提供的代码片段中,首先使用neuron.IFNode创建了一个积分并发放(Integrate-and-Fire, IF)神经元节点,这是SNN中最基本的神经元模型之一。然后,演示了如何使用multi_step_forward函数和MultiStepContainer模块来处理一个时间序列的输入。]
multi_step_forward函数 处理的是序列经过网络 多步传输
MultiStepContainer模块 处理的是网络,然后用网络来通过序列
【对于无状态的ANN网络层,例如 torch.nn.Conv2d,其本身要求输入数据的 shape = [N, *],若用于多步模式,则可以用多步的包装器进行包装:】
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer
with torch.no_grad():
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
y_seq = functional.multi_step_forward(x_seq, (conv, bn))
# y_seq.shape = [T, N, 8, H, W]
net = layer.MultiStepContainer(conv, bn)
z_seq = net(x_seq)
# z_seq.shape = [T, N, 8, H, W]
# z_seq is identical to y_seq
【但是ANN的网络层本身是无状态的,不存在前序依赖,没有必要在时间上串行的计算,可以使用函数风格的 seq_to_ann_forward 或模块风格的 SeqToANNContainer 进行包装。seq_to_ann_forward 将 shape = [T, N, *] 的数据首先变换为 shape = [TN, *],再送入无状态的网络层进行计算,输出的结果会被重新变换为 shape = [T, N, *]。不同时刻的数据是并行计算的,因而速度更快:】
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer
with torch.no_grad():
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
y_seq = functional.multi_step_forward(x_seq, (conv, bn))
# y_seq.shape = [T, N, 8, H, W]
net = layer.MultiStepContainer(conv, bn)
z_seq = net(x_seq)
# z_seq.shape = [T, N, 8, H, W]
# z_seq is identical to y_seq
p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
# p_seq.shape = [T, N, 8, H, W]
net = layer.SeqToANNContainer(conv, bn)
q_seq = net(x_seq)
# q_seq.shape = [T, N, 8, H, W]
# q_seq is identical to p_seq, and also identical to y_seq and z_seq
【这两段代码展示了如何将SNN的时间序列输入转换为适用于传统人工神经网络(ANN)的形式。functional.seq_to_ann_forward是一个函数风格的工具,而SeqToANNContainer提供了一个模块风格的包装器。两者都实现了相同的功能,即处理时间序列输入并将其转换为ANN可以处理的形式。
总结
这段代码通过不同的方法展示了如何在SpikingJelly中处理SNN的时间序列输入。它使用了函数风格的multi_step_forward和seq_to_ann_forward,以及模块风格的MultiStepContainer和SeqToANNContainer来实现这一目的。这些工具和包装器使得在PyTorch框架下构建和模拟SNN变得更加灵活和方便。】
【常用的网络层,在 spikingjelly.activation_based.layer 已经定义过,更推荐使用 spikingjelly.activation_based.layer 中的网络层,而不是使用 SeqToANNContainer 手动包装,尽管 spikingjelly.activation_based.layer 中的网络层实际上就是用包装器包装 forward 函数实现的。spikingjelly.activation_based.layer 中的网络层,优势在于:
支持单步和多步模式,而 SeqToANNContainer 和 MultiStepContainer 包装的层,只支持多步模式
包装器会使得 state_dict 的 keys() 也增加一层包装,给加载权重带来麻烦】
最好不要用函数风格的 seq_to_ann_forward 和模块风格的 SeqToANNContainer
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer, neuron
ann = nn.Sequential(
nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
nn.ReLU()
)
print(f'ann.state_dict.keys()={ann.state_dict().keys()}')
net_container = nn.Sequential(
layer.SeqToANNContainer(
nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
),
neuron.IFNode(step_mode='m')
)
print(f'net_container.state_dict.keys()={net_container.state_dict().keys()}')
net_origin = nn.Sequential(
layer.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
neuron.IFNode(step_mode='m')
)
print(f'net_origin.state_dict.keys()={net_origin.state_dict().keys()}')
try:
print('net_container is trying to load state dict from ann...')
net_container.load_state_dict(ann.state_dict())
print('Load success!')
except BaseException as e:
print('net_container can not load! The error message is\n', e)
try:
print('net_origin is trying to load state dict from ann...')
net_origin.load_state_dict(ann.state_dict())
print('Load success!')
except BaseException as e:
print('net_origin can not load! The error message is', e)
ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])
net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container is trying to load state dict from ann...
net_container can not load! The error message is
Error(s) in loading state_dict for Sequential:
Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var".
Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked".
net_origin is trying to load state dict from ann...
Load success!
用SeqToANNContainer 构建的网络,不能把值加载进去,因为0.0 和0 不一样,有区别
【MultiStepContainer 和 SeqToANNContainer 都是只支持多步模式的,不允许切换为单步模式。
StepModeContainer 类似于融合版的 MultiStepContainer 和 SeqToANNContainer,可以用于包装无状态或有状态的单步模块,需要在包装时指明是否有状态,但此包装器还支持切换单步和多步模式。】
无状态的ann,net = layer.StepModeContainer(
import torch
from spikingjelly.activation_based import neuron, layer
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
net = layer.StepModeContainer(
False,
nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C),
)
net.step_mode = 'm'
y_seq = net(x_seq)
# y_seq.shape = [T, N, C, H, W]
net.step_mode = 's'
y = net(x_seq[0])
# y.shape = [N, C, H, W]
有snn的 net = layer.StepModeContainer(
True,
neuron.IFNode()
import torch
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
net = layer.StepModeContainer(
True,
neuron.IFNode()
)
net.step_mode = 'm'
y_seq = net(x_seq)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net)
net.step_mode = 's'
y = net(x_seq[0])
# y.shape = [N, C, H, W]
net.step_mode=m
net[0].step_mode=s
[如果模块本身就支持单步和多步模式的切换,则不推荐使用 MultiStepContainer 或 StepModeContainer 对其进行包装。因为包装器使用的多步前向传播,可能不如模块自身定义的前向传播速度快。]
[通常需要用到 MultiStepContainer 或 StepModeContainer 的是一些没有定义多步的模块,例如一个在 torch.nn 中存在,但在 spikingjelly.activation_based.layer 中不存在的网络层。]
特殊的层,使用