Here is how I do it:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from enum import Enum
from torch.nn.parameter import Parameter
# Paper: QUANTIZED SPIKE-DRIVEN TRANSFORMER
# Paper link: https://arxiv.org/pdf/2501.13492
# Official GitHub: https://github.com/bollossom/QSD-Transformer/blob/main/classification/quan_w.py
# Code improved by: 一勺汤
class ReLUX(nn.Module):
def __init__(self, thre=8):
super(ReLUX, self).__init__()
self.thre = thre
def forward(self, input):
return torch.clamp(input, 0, self.thre)
relu4 = ReLUX(thre=4)
class multispike(torch.autograd.Function):
@staticmethod
def forward(ctx, input, lens):
ctx.save_for_backward(input)
ctx.lens = lens
return torch.floor(relu4(input) + 0.5)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
temp1 = 0 < input
temp2 = input < ctx.lens
return grad_input * temp1.float() * temp2.float(), None
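# Note on multispike: the forward pass quantizes the input to an integer spike count in
# {0, 1, 2, 3, 4} via floor(clamp(x, 0, 4) + 0.5), i.e. round-to-nearest within the clamp
# range. The backward pass uses a rectangular surrogate gradient that passes grad_output
# through only where 0 < input < lens and blocks it elsewhere.
# Illustrative values: multispike.apply(torch.tensor([-0.2, 0.3, 1.6, 5.0]), 4)
# returns tensor([0., 0., 2., 4.]).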
class Multispike(nn.Module):
def __init__(self, lens=4, spike=multispike):
super().__init__()
self.lens = lens
self.spike = spike
def forward(self, inputs):
return self.spike.apply(4 * inputs, self.lens) / 4
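# Multispike wraps the autograd function: inputs are scaled by 4 before spiking and the
# integer spike count is divided by 4 afterwards, so the effective activation takes values
# in {0, 0.25, 0.5, 0.75, 1.0} (a 4-level "multi-spike" activation).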
def grad_scale(x, scale):
y = x
y_grad = x * scale
return y.detach() - y_grad.detach() + y_grad
def round_pass(x):
y = x.round()
y_grad = x
return y.detach() - y_grad.detach() + y_grad
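# grad_scale and round_pass use the straight-through-estimator idiom
# y.detach() - y_grad.detach() + y_grad: the expression evaluates to y in the forward pass,
# while autograd only sees y_grad. So round_pass behaves like round() forward and identity
# backward, and grad_scale keeps the value of x but multiplies its gradient by `scale`.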
class Qmodes(Enum):
layer_wise = 1
kernel_wise = 2
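# layer_wise: a single quantization step size (alpha) per layer;
# kernel_wise: one alpha per output channel (conv) or per feature (linear / activation).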
class _LinearQ(nn.Linear):
def __init__(self, in_features, out_features, bias=True, **kwargs_q):
#print(in_features, out_features)
super(_LinearQ, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
self.nbits = kwargs_q['nbits']
if self.nbits < 0:
self.register_parameter('alpha', None)
return
self.q_mode = kwargs_q['mode']
self.alpha = Parameter(torch.Tensor(1))
if self.q_mode == Qmodes.kernel_wise:
self.alpha = Parameter(torch.Tensor(out_features))
self.register_buffer('init_state', torch.zeros(1))
def add_param(self, param_k, param_v):
self.kwargs_q[param_k] = param_v
def extra_repr(self):
s_prefix = super(_LinearQ, self).extra_repr()
if self.alpha is None:
return '{}, fake'.format(s_prefix)
return '{}, {}'.format(s_prefix, self.kwargs_q)
class _ActQ(nn.Module):
def __init__(self, in_features, **kwargs_q):
super(_ActQ, self).__init__()
self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
self.nbits = kwargs_q['nbits']
if self.nbits < 0:
self.register_parameter('alpha', None)
self.register_parameter('zero_point', None)
return
# self.signed = kwargs_q['signed']
self.q_mode = kwargs_q['mode']
self.alpha = Parameter(torch.Tensor(1))
self.zero_point = Parameter(torch.Tensor([0]))
if self.q_mode == Qmodes.kernel_wise:
self.alpha = Parameter(torch.Tensor(in_features))
self.zero_point = Parameter(torch.Tensor(in_features))
torch.nn.init.zeros_(self.zero_point)
# self.zero_point = Parameter(torch.Tensor([0]))
self.register_buffer('init_state', torch.zeros(1))
self.register_buffer('signed', torch.zeros(1))
def add_param(self, param_k, param_v):
self.kwargs_q[param_k] = param_v
def set_bit(self, nbits):
self.kwargs_q['nbits'] = nbits
def extra_repr(self):
# s_prefix = super(_ActQ, self).extra_repr()
if self.alpha is None:
return 'fake'
return '{}'.format(self.kwargs_q)
def get_default_kwargs_q(kwargs_q, layer_type):
default = {
'nbits': 4
}
if isinstance(layer_type, _Conv2dQ):
default.update({
'mode': Qmodes.layer_wise})
elif isinstance(layer_type, _LinearQ):
pass
elif isinstance(layer_type, _ActQ):
pass
# default.update({
# 'signed': 'Auto'})
else:
        raise NotImplementedError(f'Unsupported layer type for quantization defaults: {type(layer_type)}')
for k, v in default.items():
if k not in kwargs_q:
kwargs_q[k] = v
return kwargs_q
class _Conv2dQ(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True, **kwargs_q):
super(_Conv2dQ, self).__init__(in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
self.nbits = kwargs_q['nbits']
if self.nbits < 0:
self.register_parameter('alpha', None)
return
self.q_mode = kwargs_q['mode']
if self.q_mode == Qmodes.kernel_wise:
self.alpha = Parameter(torch.Tensor(out_channels))
else: # layer-wise quantization
self.alpha = Parameter(torch.Tensor(1))
self.register_buffer('init_state', torch.zeros(1))
def add_param(self, param_k, param_v):
self.kwargs_q[param_k] = param_v
def set_bit(self, nbits):
self.kwargs_q['nbits'] = nbits
def extra_repr(self):
s_prefix = super(_Conv2dQ, self).extra_repr()
if self.alpha is None:
return '{}, fake'.format(s_prefix)
return '{}, {}'.format(s_prefix, self.kwargs_q)
class ActLSQ(_ActQ):
def __init__(self, in_features, nbits_a=4, mode=Qmodes.kernel_wise, **kwargs):
super(ActLSQ, self).__init__(in_features=in_features, nbits=nbits_a, mode=mode)
# print(self.alpha.shape, self.zero_point.shape)
def forward(self, x):
return x
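# In this variant ActLSQ.forward is an identity pass-through: the spiking neurons
# (Multispike) already produce quantized activations, so ActLSQ presumably only keeps the
# LSQ parameter interface (alpha / zero_point) without re-quantizing x.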
class Conv2dLSQ(_Conv2dQ):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True, nbits_w=4, mode=Qmodes.kernel_wise, **kwargs):
super(Conv2dLSQ, self).__init__(
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias,
nbits=nbits_w, mode=mode)
self.act = ActLSQ(in_features=in_channels, nbits_a=nbits_w)
def forward(self, x):
if self.alpha is None:
return F.conv2d(x, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
# w_reshape = self.weight.reshape([self.weight.shape[0], -1]).transpose(0, 1)
Qn = -2 ** (self.nbits - 1)
Qp = 2 ** (self.nbits - 1) - 1
if self.training and self.init_state == 0:
# self.alpha.data.copy_(self.weight.abs().max() / 2 ** (self.nbits - 1))
self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
# self.alpha.data.copy_(self.weight.abs().max() * 2)
self.init_state.fill_(1)
"""
Implementation according to paper.
Feels wrong ...
When we initialize the alpha as a big number (e.g., self.weight.abs().max() * 2),
the clamp function can be skipped.
Then we get w_q = w / alpha * alpha = w, and $\frac{\partial w_q}{\partial \alpha} = 0$
As a result, I don't think the pseudo-code in the paper echoes the formula.
        Please see jupyter/STE_LSQ.ipynb for a detailed comparison.
"""
g = 1.0 / math.sqrt(self.weight.numel() * Qp)
# Method1: 31GB GPU memory (AlexNet w4a4 bs 2048) 17min/epoch
alpha = grad_scale(self.alpha, g)
# print(alpha.shape)
# print(self.weight.shape)
alpha = alpha.unsqueeze(1).unsqueeze(2).unsqueeze(3)
w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha
x = self.act(x)
# w = w.clamp(Qn, Qp)
# q_w = round_pass(w)
# w_q = q_w * alpha
# Method2: 25GB GPU memory (AlexNet w4a4 bs 2048) 32min/epoch
# w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp)
# wq = y.transpose(0, 1).reshape(self.weight.shape).detach() + self.weight - self.weight.detach()
return F.conv2d(x, w_q, self.bias, self.stride,
self.padding, self.dilation, self.groups)
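# LSQ weight quantization in Conv2dLSQ.forward, in brief:
#   - Qn, Qp = -2**(nbits-1), 2**(nbits-1) - 1, i.e. [-8, 7] for the default 4 bits;
#   - on the first training step, alpha is initialized to 2 * mean(|w|) / sqrt(Qp);
#   - g = 1 / sqrt(numel(w) * Qp) rescales alpha's gradient (via grad_scale), as in the
#     LSQ paper, to balance step-size and weight updates;
#   - w_q = round(clamp(w / alpha, Qn, Qp)) * alpha, with round_pass providing a
#     straight-through gradient for the rounding.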
class BNAndPadLayer(nn.Module):
def __init__(
self,
pad_pixels,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
):
super(BNAndPadLayer, self).__init__()
self.bn = nn.BatchNorm2d(
num_features, eps, momentum, affine, track_running_stats
)
self.pad_pixels = pad_pixels
def forward(self, input):
output = self.bn(input)
if self.pad_pixels > 0:
if self.bn.affine:
pad_values = (
self.bn.bias.detach()
- self.bn.running_mean
* self.bn.weight.detach()
/ torch.sqrt(self.bn.running_var + self.bn.eps)
)
else:
pad_values = -self.bn.running_mean / torch.sqrt(
self.bn.running_var + self.bn.eps
)
output = F.pad(output, [self.pad_pixels] * 4)
pad_values = pad_values.view(1, -1, 1, 1)
output[:, :, 0: self.pad_pixels, :] = pad_values
output[:, :, -self.pad_pixels:, :] = pad_values
output[:, :, :, 0: self.pad_pixels] = pad_values
output[:, :, :, -self.pad_pixels:] = pad_values
return output
@property
def weight(self):
return self.bn.weight
@property
def bias(self):
return self.bn.bias
@property
def running_mean(self):
return self.bn.running_mean
@property
def running_var(self):
return self.bn.running_var
@property
def eps(self):
return self.bn.eps
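# BNAndPadLayer pads the BN output with the value a zero input would take after this
# BatchNorm (bias - running_mean * weight / sqrt(running_var + eps), or the non-affine
# equivalent), so the 1-pixel border added here behaves like zero padding applied before
# the BN rather than after it.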
class RepConv(nn.Module):
def __init__(
self,
in_channel,
out_channel,
bias=False,
):
super().__init__()
# hidden_channel = in_channel
conv1x1 = Conv2dLSQ(in_channel, in_channel, 1, 1, 0, bias=False, groups=1)
bn = BNAndPadLayer(pad_pixels=1, num_features=in_channel)
conv3x3 = nn.Sequential(
Conv2dLSQ(in_channel, in_channel, 3, 1, 0, groups=in_channel, bias=False),
Conv2dLSQ(in_channel, out_channel, 1, 1, 0, groups=1, bias=False),
nn.BatchNorm2d(out_channel),
)
self.body = nn.Sequential(conv1x1, bn, conv3x3)
def forward(self, x):
return self.body(x)
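# RepConv is the re-parameterizable convolution block used by the attention branch:
# a quantized 1x1 conv, a BN whose padding matches zero-padding before BN, then a depthwise
# 3x3 conv with padding 0 (which consumes the 1-pixel ring) followed by a pointwise 1x1
# conv and a BatchNorm2d; all convs are Conv2dLSQ and spatial resolution is preserved.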
class Multispike_att(nn.Module):
def __init__(self, lens=4, spike=multispike):
super().__init__()
self.lens = lens
self.spike = spike
def forward(self, inputs):
return self.spike.apply(4 * inputs, self.lens) / 2
class MS_Attention_RepConv_qkv_id(nn.Module):
def __init__(
self,
dim,
num_heads=8,
):
super().__init__()
assert (
dim % num_heads == 0
), f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
self.scale = 0.25
self.head_lif = Multispike()
self.q_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim))
self.k_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim))
self.v_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim))
self.q_lif = Multispike()
self.k_lif = Multispike()
self.v_lif = Multispike()
self.attn_lif = Multispike_att()
self.proj_conv = nn.Sequential(
RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)
)
def forward(self, x):
x = x.unsqueeze(0)
T, B, C, H, W = x.shape
N = H * W
x = self.head_lif(x)
q = self.q_conv(x.flatten(0, 1)).reshape(T, B, C, H, W)
k = self.k_conv(x.flatten(0, 1)).reshape(T, B, C, H, W)
v = self.v_conv(x.flatten(0, 1)).reshape(T, B, C, H, W)
q = self.q_lif(q).flatten(3)
q = (
q.transpose(-1, -2)
.reshape(T, B, N, self.num_heads, C // self.num_heads)
.permute(0, 1, 3, 2, 4)
.contiguous()
)
k = self.k_lif(k).flatten(3)
k = (
k.transpose(-1, -2)
.reshape(T, B, N, self.num_heads, C // self.num_heads)
.permute(0, 1, 3, 2, 4)
.contiguous()
)
v = self.v_lif(v).flatten(3)
v = (
v.transpose(-1, -2)
.reshape(T, B, N, self.num_heads, C // self.num_heads)
.permute(0, 1, 3, 2, 4)
.contiguous()
)
x = k.transpose(-2, -1) @ v
x = (q @ x) * self.scale
x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
x = self.attn_lif(x).reshape(T, B, C, H, W)
x = x.reshape(T, B, C, H, W)
x = x.flatten(0, 1)
x = self.proj_conv(x).reshape(T, B, C, H, W)
x = x.squeeze(0)
return x
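# Shape walk-through of the spike-driven attention above (T=1 is a dummy time step added
# by unsqueeze(0) so 4D (B, C, H, W) feature maps fit the (T, B, C, H, W) SNN-style
# interface):
#   q, k, v: (T, B, num_heads, N, C // num_heads) with N = H * W;
#   k^T @ v is computed first, giving a (C//num_heads) x (C//num_heads) matrix per head,
#   then q @ (k^T v) * 0.25, so the cost is linear in the number of tokens N rather than
#   quadratic as in standard softmax attention.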
def autopad(k, p=None, d=1): # kernel, padding, dilation
"""Pad to 'same' shape outputs."""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
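# e.g. autopad(3) -> 1 and autopad(3, d=2) -> 2 (a dilated 3x3 has an effective 5x5 kernel).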
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
"""Perform transposed convolution of 2D data."""
return self.act(self.conv(x))
class PSABloc_MSAR(nn.Module):
"""
PSABlock class implementing a Position-Sensitive Attention block for neural networks.
This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers
with optional shortcut connections.
Attributes:
attn (Attention): Multi-head attention module.
ffn (nn.Sequential): Feed-forward neural network module.
add (bool): Flag indicating whether to add shortcut connections.
Methods:
forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.
Examples:
Create a PSABlock and perform a forward pass
>>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)
>>> input_tensor = torch.randn(1, 128, 32, 32)
>>> output_tensor = psablock(input_tensor)
"""
def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None:
"""Initializes the PSABlock with attention and feed-forward layers for enhanced feature extraction."""
super().__init__()
self.attn = MS_Attention_RepConv_qkv_id(dim=c, num_heads=num_heads)
self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
self.add = shortcut
def forward(self, x):
"""Executes a forward pass through PSABlock, applying attention and feed-forward layers to the input tensor."""
x = x + self.attn(x) if self.add else self.attn(x)
x = x + self.ffn(x) if self.add else self.ffn(x)
return x
class C2PSA_MSAR(nn.Module):
"""
C2PSA module with attention mechanism for enhanced feature extraction and processing.
This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing
capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.
Attributes:
c (int): Number of hidden channels.
cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.
Methods:
forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.
Notes:
This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.
Examples:
>>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)
>>> input_tensor = torch.randn(1, 256, 64, 64)
>>> output_tensor = c2psa(input_tensor)
"""
def __init__(self, c1, c2, n=1, e=0.5):
"""Initializes the C2PSA module with specified input/output channels, number of layers, and expansion ratio."""
super().__init__()
assert c1 == c2
self.c = int(c1 * e)
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
self.cv2 = Conv(2 * self.c, c1, 1)
self.m = nn.Sequential(*(PSABloc_MSAR(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
def forward(self, x):
"""Processes the input tensor 'x' through a series of PSA blocks and returns the transformed tensor."""
a, b = self.cv1(x).split((self.c, self.c), dim=1)
b = self.m(b)
return self.cv2(torch.cat((a, b), 1))
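# C2PSA_MSAR follows the usual C2PSA pattern: cv1 expands the input to 2*c channels and
# splits it into two halves; one half passes through n PSABloc_MSAR blocks (the
# spike-driven attention + 1x1-conv FFN defined above) while the other is carried over
# unchanged, and cv2 fuses the concatenation back to the original channel count.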
def main():
    # Set a random seed so the results are reproducible
    torch.manual_seed(42)
    # Define the input tensor (batch size B=2, channels C=64, height H=7, width W=16)
    B, C, H, W = 2, 64, 7, 16
    x = torch.randn(B, C, H, W)  # randomly generated input tensor
    # Instantiate the MS_Attention_RepConv_qkv_id module
    dim = C  # number of input channels
    num_heads = 8  # number of attention heads
    attention_module = MS_Attention_RepConv_qkv_id(dim=dim, num_heads=num_heads)
    # Print the shape of the input tensor
    print("Input shape:", x.shape)
    # Forward pass
    output = attention_module(x)
    # Print the shape of the output tensor
    print("Output shape:", output.shape)
    # Print the minimum and maximum values of the output tensor
    print("Output min value:", output.min().item())
    print("Output max value:", output.max().item())
if __name__ == "__main__":
main()