SINABS使用

本文介绍了瑞士类脑芯片公司aiCTX开源的SNN仿真库SINABS。先阐述了Spiking Neurons相关知识,包括人工神经元模型、LIF和IAF模型及激活函数;接着介绍SINABS提供的例程,如MNIST例程中CNN转SCNN,以及使用BPTT训练SNN的方法。

SINABS是瑞士类脑芯片公司aiCTX开源的SNN仿真库,https://sinabs.ai对该库进行了详细的介绍。

1.Spiking Neurons简介

1.1人工神经元(Artificial Neuron)模型

人工神经元模型可以表示为y=f(Wx+b),其中f为非线性激活函数。

人工神经元的的输出直接依赖于输入,神经元没有影响到其输出的内部状态;而脉冲神经网络SNN增加了神经元对当前状态的额外依赖。

1.2 LIF(Leaky Integrate and Fire)模型

LIF模型是一个简单的脉冲神经元(Spiking Neuron)模型,其中神经元的状态通常为膜电位v,取决于输入和先前的状态,神经元模型可以表述如下:

其中tau为膜时间常数,定义了神经元对先前状态和输入的依赖程度;该模型描述了具有泄露动力学的系统,膜电位v随着时间推移缓慢泄露为0,也是LIF(Leaky Integrate and Fire)模型的由来。公式中Isyn=Wx,是所有输入突触的加权和,Ibias为输入偏置。R表示常数电阻,用于膜电位和电流电位之间的匹配。

神经元的输出为二值和瞬时的,,当v达到阈值vth时,神经元输出才输出1,同时膜电位立即重置为vreset(该值小于阈值vth)。因此v可以表示为:

1被称为时间t的尖峰,一系列的尖峰被称为s(t):

膜电位随时间演变的轨迹如下:

1.3 IAF(Constant-leak Integrate and Fire)模型

LIF模型描述了一个动态系统,v是不断更新的,为了便于计算,假设膜电位只有一个恒定的泄漏,为IAF模型,表述如下:

vleak为常数,可以与偏差结合,上式可以简化为:

IAF膜电位随着时间演变的轨迹为:

1.4 激活函数

对于脉冲神经网络来说,定义激活函数较为复杂,因为脉冲神经网络的输出与输入x和自身的状态相关,神经元的输出为一系列尖峰和狄拉克三角函数系列。通过观察神经元在一个时间窗口内产生的尖峰数量可以解释神经元对输入的反应,这样以神经元的峰值速率对峰值进行解释的方法称为速率编码。

基于对脉冲神经元的速率编码解释,IAF神经元模型的传递函数展示如下:

因此IAF神经元的传递函数等价于ReLU激活函数。

2.SINABS提供的例程

2.1 MNIST

该例程主要对如何建立一个脉冲神经网络进行介绍:

  • 使用Pytorch定义一个卷积神经网络,具有三个卷积层,一个全连接层和一个输出层
import torch.nn as nn

ann = nn.Sequential(
    nn.Conv2d(1, 20, 5, 1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2,2),
    nn.Conv2d(20, 32, 5, 1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2,2),
    nn.Conv2d(32, 128, 3, 1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2,2),
    nn.Flatten(),
    nn.Linear(128, 500, bias=False),
    nn.ReLU(),
    nn.Linear(500, 10, bias=False),
)
  • 加载数据(例程目的是为了模拟一个SNN网络,网络返回一系列脉冲)
    import numpy as np
    from PIL import Image
    from torchvision import datasets
    
    class MNIST_Dataset(datasets.MNIST):
    
        def __init__(self, root, train = True, spiking=False, tWindow=100):
            datasets.MNIST.__init__(self, root, train=train, download=True)
            self.spiking=spiking
            self.tWindow = tWindow
    
    
        def __getitem__(self, index):
            img, target = self.data[index], self.targets[index]
    
            if self.spiking:
                img = (np.random.rand(self.tWindow, 1, *img.size()) < img.numpy()/255.0).astype(float)
                img = torch.from_numpy(img).float()
            else:
                # Convert PIL image to tensor
                img = torch.from_numpy(img.numpy()).float()
                img.unsqueeze_(0)
    
            return img, target
    

     

  • 训练网络,其中加载train_loader时设置spiking=False,即使用费脉冲化的常规图像训练模型,使用Adam优化器、学习率为0.0001,损失函数为交叉熵损失;该例程仅训练了3个epochs以作示意。

from torch.utils.data import DataLoader

# Define test dataset loader
train_loader = DataLoader(
    MNIST_Dataset('./data', train=True, spiking=False),
    batch_size=128, shuffle=True)

import tqdm
import torch
import torch.nn.functional as F
import torch.optim as optim


try:
    # Load a pre-trained model to save time if you have already have one.
    ann.load_state_dict(torch.load("mnist_params.pt"))
except:
    # Train the model

    ann.train()

    optim = torch.optim.Adam(ann.parameters(), lr=1e-4)

    n_epochs = 3

    for n in tqdm.notebook.tqdm(range(n_epochs)):
        pbar = tqdm.notebook.tqdm(train_loader)
        # Iterate over data
        for data, target in pbar:
            data, target = data.to(device), target.to(device)
            output = ann(data)
            optim.zero_grad()

            # Add loss to the total loss
            loss = F.cross_entropy(output, target)

            # Propagate loss backwards
            loss.backward()

            # Update weights
            optim.step()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            # Compute the total correct predictions
            correct = pred.eq(target.view_as(pred)).sum().item()

            pbar.set_postfix({"loss": loss.item(), "accuracy": correct/(len(target))})

    # Save model parameters
    torch.save(ann.state_dict(), "mnist_params.pt")
  • 下面介绍如何将CNN结构转换为SCNN结构,sinabs中的from_model函数可以将CNN转换为SCNN
from sinabs.from_torch import from_model

input_shape = (1, 28, 28)

sinabs_model = from_model(ann, input_shape=input_shape, add_spiking_output=True)

  • 对sinabs_model进行可视化
sinabs_model.spiking_model

  • 进行一个简单的测试,其中数据集加载函数中的spiking应当设置为True,其中tWindow是一个十分重要的参数,应当多次进行调试
# Time window per sample
tWindow = 200 # ms (or) time steps

# Define test dataset loader
test_spike_loader = torch.utils.data.DataLoader(
    MNIST_Dataset('./data', train=False, spiking=True, tWindow=tWindow),
    batch_size=1, shuffle=False)

test(sinabs_model, test_spike_loader, num_batches=200)
  • sinabs集成了计算突触操作总数的方法:
sinabs_model.get_synops()

2.2 使用BPTT进行训练

BPTT一般在训练卷积神经网络中使用,与普通的神经网络不同,SNN的内部状态会持续一段时间,即使网络不是循环出现的,仍然可以通过膜电位的持续性记忆其先前的处理步骤。

  • 定义数据集,其中像素值介于0-1之间,将这些值转化为峰值概率
from torchvision import datasets
import torch

torch.manual_seed(0)

class MNIST_Dataset(datasets.MNIST):
    def __init__(self, root, train=True, single_channel=False):
        datasets.MNIST.__init__(self, root, train=train, download=True)
        self.single_channel = single_channel

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = img.float() / 255.

        # default is  by row, output is [time, channels] = [28, 28]
        # OR if we want by single item, output is [784, 1]
        if self.single_channel:
            img = img.reshape(-1).unsqueeze(1)

        spikes = torch.rand(size=img.shape) < img
        spikes = spikes.float()

        return spikes, target

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 64

dataset_test = MNIST_Dataset(root="./data/", train=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, drop_last=True)

dataset = MNIST_Dataset(root="./data/", train=True)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, drop_last=True)
  • 训练一个baseline
from torch import nn

ann = nn.Sequential(
    nn.Linear(28, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.ReLU()
)
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters())

for epoch in range(5):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()

        target = target.unsqueeze(1).repeat([1, 28])
        img = img.reshape([-1, 28])
        target = target.reshape([-1])

        out = ann(img)
#         out = out.sum(1)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())
  • 定义一个SNN并训练,网络状态必须在每次迭代时重置
from sinabs.from_torch import from_model

model = from_model(ann, batch_size=BATCH_SIZE).to(device)
model = model.train()

from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(10):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()
        model.reset_states()

        out = model.spiking_model(img.to(device))
        # the output of the network is summed over the 28 time steps (rows)
        out = out.sum(1)
        loss = criterion(out, target.to(device))
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())
  • 测试
accs = []

pbar = tqdm(dataloader_test)
for img, target in pbar:
    model.reset_states()

    out = model(img.to(device))
    out = out.sum(1)

    predicted = torch.max(out, axis=1)[1]
    acc = (predicted == target.to(device)).sum().cpu().numpy() / BATCH_SIZE
    accs.append(acc)

print(sum(accs)/len(accs))

 

评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

space_dandy

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值