Core idea of the paper
The paper proposes an unstructured pruning strategy that prunes convolutional-layer weights, and introduces the now-famous three-step pruning pipeline (train, prune, retrain). Weight importance is judged by applying L1 or L2 regularization to the weights; then, according to a given pruning ratio, the weights with the smallest values are set to zero.
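As a toy illustration (mine, not the paper's code), the thresholding step can be written in a few lines of PyTorch: convert the prune ratio into a magnitude threshold and zero everything below it:
import torch

def magnitude_prune(weight: torch.Tensor, prune_rate: float) -> torch.Tensor:
    """Zero the smallest-magnitude entries so roughly prune_rate percent become zero.
    A toy sketch of the idea, not the paper's implementation."""
    # the prune_rate-th percentile of |w| becomes the threshold
    threshold = torch.quantile(weight.abs().flatten(), prune_rate / 100.0)
    mask = (weight.abs() > threshold).float()
    return weight * mask

w = torch.randn(64, 3, 3, 3)
w_pruned = magnitude_prune(w, prune_rate=50.0)
print((w_pruned == 0).float().mean())  # ~0.5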
A closer look at the details
Why model compression matters: the paper motivates compression from the power-consumption angle. The larger the network, the more parameters and computation it needs, and the more energy a forward pass costs, which is harsh for mobile and embedded devices. Moreover, a large model cannot fit in SRAM (for process reasons SRAM is usually only a few MB to a few tens of MB) and must live in DRAM, and reading model data from DRAM consumes more than 100x the energy of reading it from SRAM.
Regularization choice: the paper evaluates both L1 and L2 regularization for judging weight importance; in the paper's experiments L2 regularization works better.
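For reference, a hedged sketch of how the two penalties usually enter training in PyTorch: L2 is typically delegated to the optimizer's weight_decay, while an L1 term has to be added to the loss by hand (the tiny model and lambda value below are illustrative, not the paper's settings):
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()

# L2 penalty: normally handled by the optimizer via weight_decay
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)

# L1 penalty: added to the loss explicitly (lambda_l1 is an illustrative value)
lambda_l1 = 1e-5

def loss_with_l1(output, target):
    ce = criterion(output, target)
    l1 = sum(p.abs().sum() for p in model.parameters() if p.dim() > 1)
    return ce + lambda_l1 * l1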
Neuron pruning: after weight pruning, the pruned weights are zero, so the corresponding neurons and gradients can be pruned (zeroed) at the same time. The paper mentions this, but the authors' released code does not actually perform this step (possibly a code-version issue; it can be extended by hand, as sketched below).
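A hedged sketch of one way to do this extension by hand: keep a binary mask, zero the pruned weights, and register a hook so their gradients also stay zero during retraining (the 0.05 threshold is arbitrary, purely for illustration):
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, 3)
mask = (conv.weight.detach().abs() > 0.05).float()  # illustrative magnitude mask

with torch.no_grad():
    conv.weight.mul_(mask)                            # zero the pruned connections
conv.weight.register_hook(lambda grad: grad * mask)  # keep their gradients at zero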
Reproducing the paper
Part 1: Choosing the dataset, model, optimizer, etc.
I use the CIFAR-10 dataset and a ResNet-18 model; other details are omitted here (the authors' repository contains all of this configuration and is worth reading). A typical setup is sketched below.
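For completeness, a typical CIFAR-10 / ResNet-18 setup; the hyper-parameters are common defaults rather than the authors' exact configuration, and torchvision's resnet18 stands in for the repo's CIFAR variant of get_model('resnet18'):
import torch
import torchvision
import torchvision.transforms as T

transform = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(),
                       T.ToTensor(),
                       T.Normalize((0.4914, 0.4822, 0.4465),
                                   (0.2470, 0.2435, 0.2616))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

model = torchvision.models.resnet18(num_classes=10)  # stand-in for get_model('resnet18')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[100, 150], gamma=0.1)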
Part 2: Model training
I added some visualization code on top of the authors' training script:
def paint_vinz(epoch):
    global test_loss, train_loss, train_top1, test_top1, train_error_history, test_error_history
    viz.line(X=epoch_p,
             Y=np.column_stack((np.array(train_error_history), np.array(test_error_history))),
             win=line, opts=dict(legend=['train_error', 'test_error']))
    # visdom's text window accepts HTML
    viz.text(
        "<p style='color:red'>epoch: {}</p><br>"
        "<p style='color:BlueViolet'>train_loss: {:.4f}</p><br>"
        "<p style='color:blue'>test_loss: {:.4f}</p><br>"
        "<p style='color:orange'>train_top1: {:.4f}</p><br>"
        "<p style='color:black'>test_top1: {:.4f}</p><br>"
        "<p style='color:green'>Time: {}</p>".format(
            epoch + 1, train_loss, test_loss, train_top1, test_top1,
            time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))), win=text)
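paint_vinz relies on a few globals (viz, line, text, epoch_p, start_time). A hedged sketch of how they might be created; the initial values are illustrative, and the visdom server must be running first (python -m visdom.server):
import time
import numpy as np
import visdom

viz = visdom.Visdom()
start_time = time.time()
epoch_p = np.array([0])                      # x-axis values, extended every epoch
train_error_history, test_error_history = [], []

# create the two windows that paint_vinz updates
line = viz.line(X=np.column_stack(([0], [0])),
                Y=np.column_stack(([0.0], [0.0])),
                opts=dict(legend=['train_error', 'test_error']))
text = viz.text("waiting for the first epoch ...")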
Part 3: Params & FLOPs of the unpruned model
Here the torchstat package is used to report the model's statistics; my code is below. Note that in their repository the paper authors do not count the model's total parameters but its nonzero parameters, which is discussed later.
# Note: this script relies on the paper authors' repo (utils, models) to run.
import os
import argparse

import torch
from torchstat import stat

from utils import *
from models import get_model

parser = argparse.ArgumentParser(description="model params counting test")
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
args = parser.parse_args()

model = get_model('resnet18')
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['net'], strict=False)
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

stat(model, (3, 32, 32))
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 conv_bn_relu.conv 3 32 32 64 32 32 1728.0 0.25 3,473,408.0 1,769,472.0 19200.0 262144.0 0.00% 281344.0
1 conv_bn_relu.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
2 conv_bn_relu.relu 64 32 32 64 32 32 0.0 0.25 65,536.0 65,536.0 262144.0 262144.0 0.00% 524288.0
3 layer1.0.conv1.conv 64 32 32 64 32 32 36864.0 0.25 75,431,936.0 37,748,736.0 409600.0 262144.0 0.00% 671744.0
4 layer1.0.conv1.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
5 layer1.0.conv1.relu 64 32 32 64 32 32 0.0 0.25 65,536.0 65,536.0 262144.0 262144.0 0.00% 524288.0
6 layer1.0.conv2.conv 64 32 32 64 32 32 36864.0 0.25 75,431,936.0 37,748,736.0 409600.0 262144.0 0.00% 671744.0
7 layer1.0.conv2.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
8 layer1.0.conv2.relu 64 32 32 64 32 32 0.0 0.25 0.0 0.0 0.0 0.0 0.00% 0.0
9 layer1.0.shortcut 64 32 32 64 32 32 0.0 0.25 0.0 0.0 0.0 0.0 0.00% 0.0
10 layer1.1.conv1.conv 64 32 32 64 32 32 36864.0 0.25 75,431,936.0 37,748,736.0 409600.0 262144.0 0.00% 671744.0
11 layer1.1.conv1.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
12 layer1.1.conv1.relu 64 32 32 64 32 32 0.0 0.25 65,536.0 65,536.0 262144.0 262144.0 0.00% 524288.0
13 layer1.1.conv2.conv 64 32 32 64 32 32 36864.0 0.25 75,431,936.0 37,748,736.0 409600.0 262144.0 0.00% 671744.0
14 layer1.1.conv2.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
15 layer1.1.conv2.relu 64 32 32 64 32 32 0.0 0.25 0.0 0.0 0.0 0.0 0.00% 0.0
16 layer1.1.shortcut 64 32 32 64 32 32 0.0 0.25 0.0 0.0 0.0 0.0 0.00% 0.0
17 layer2.0.conv1.conv 64 32 32 128 16 16 73728.0 0.12 37,715,968.0 18,874,368.0 557056.0 131072.0 0.00% 688128.0
18 layer2.0.conv1.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
19 layer2.0.conv1.relu 128 16 16 128 16 16 0.0 0.12 32,768.0 32,768.0 131072.0 131072.0 0.00% 262144.0
20 layer2.0.conv2.conv 128 16 16 128 16 16 147456.0 0.12 75,464,704.0 37,748,736.0 720896.0 131072.0 0.00% 851968.0
21 layer2.0.conv2.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
22 layer2.0.conv2.relu 128 16 16 128 16 16 0.0 0.12 0.0 0.0 0.0 0.0 0.00% 0.0
23 layer2.0.shortcut.conv_bn.conv 64 32 32 128 16 16 8192.0 0.12 4,161,536.0 2,097,152.0 294912.0 131072.0 0.00% 425984.0
24 layer2.0.shortcut.conv_bn.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
25 layer2.0.shortcut.conv_bn.relu 128 16 16 128 16 16 0.0 0.12 0.0 0.0 0.0 0.0 0.00% 0.0
26 layer2.1.conv1.conv 128 16 16 128 16 16 147456.0 0.12 75,464,704.0 37,748,736.0 720896.0 131072.0 0.00% 851968.0
27 layer2.1.conv1.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
28 layer2.1.conv1.relu 128 16 16 128 16 16 0.0 0.12 32,768.0 32,768.0 131072.0 131072.0 0.00% 262144.0
29 layer2.1.conv2.conv 128 16 16 128 16 16 147456.0 0.12 75,464,704.0 37,748,736.0 720896.0 131072.0 0.00% 851968.0
30 layer2.1.conv2.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
31 layer2.1.conv2.relu 128 16 16 128 16 16 0.0 0.12 0.0 0.0 0.0 0.0 0.00% 0.0
32 layer2.1.shortcut 128 16 16 128 16 16 0.0 0.12 0.0 0.0 0.0 0.0 0.00% 0.0
33 layer3.0.conv1.conv 128 16 16 256 8 8 294912.0 0.06 37,732,352.0 18,874,368.0 1310720.0 65536.0 0.00% 1376256.0
34 layer3.0.conv1.bn 256 8 8 256 8 8 512.0 0.06 65,536.0 32,768.0 67584.0 65536.0 0.00% 133120.0
35 layer3.0.conv1.relu 256 8 8 256 8 8 0.0 0.06 16,384.0 16,384.0 65536.0 65536.0 0.00% 131072.0
36 layer3.0.conv2.conv 256 8 8 256 8 8 589824.0 0.06 75,481,088.0 37,748,736.0 2424832.0 65536.0 0.00% 2490368.0
37 layer3.0.conv2.bn 256 8 8 256 8 8 512.0 0.06 65,536.0 32,768.0 67584.0 65536.0 0.00% 133120.0
38 layer3.0.conv2.relu 256 8 8 256 8 8 0.0 0.06 0.0 0.0 0.0 0.0 0.00% 0.0
39 layer3.0.shortcut.conv_bn.conv 128 16 16 256 8 8 32768.0 0.06 4,177,920.0 2,097,152.0 262144.0 65536.0 0.00% 327680.0
40 layer3.0.shortcut.conv_bn.bn 256 8 8 256 8 8 512.0 0.06 65,536.0 32,768.0 67584.0 65536.0 0.00% 133120.0
41 layer3.0.shortcut.conv_bn.relu 256 8 8 256 8 8 0.0 0.06 0.0 0.0 0.0 0.0 0.00% 0.0
42 layer3.1.conv1.conv 256 8 8 256 8 8 589824.0 0.06 75,481,088.0 37,748,736.0 2424832.0 65536.0 0.00% 2490368.0
43 layer3.1.conv1.bn 256 8 8 256 8 8 512.0 0.06 65,536.0 32,768.0 67584.0 65536.0 0.00% 133120.0
44 layer3.1.conv1.relu 256 8 8 256 8 8 0.0 0.06 16,384.0 16,384.0 65536.0 65536.0 0.00% 131072.0
45 layer3.1.conv2.conv 256 8 8 256 8 8 589824.0 0.06 75,481,088.0 37,748,736.0 2424832.0 65536.0 0.00% 2490368.0
46 layer3.1.conv2.bn 256 8 8 256 8 8 512.0 0.06 65,536.0 32,768.0 67584.0 65536.0 0.00% 133120.0
47 layer3.1.conv2.relu 256 8 8 256 8 8 0.0 0.06 0.0 0.0 0.0 0.0 0.00% 0.0
48 layer3.1.shortcut 256 8 8 256 8 8 0.0 0.06 0.0 0.0 0.0 0.0 0.00% 0.0
49 layer4.0.conv1.conv 256 8 8 512 4 4 1179648.0 0.03 37,740,544.0 18,874,368.0 4784128.0 32768.0 0.00% 4816896.0
50 layer4.0.conv1.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
51 layer4.0.conv1.relu 512 4 4 512 4 4 0.0 0.03 8,192.0 8,192.0 32768.0 32768.0 0.00% 65536.0
52 layer4.0.conv2.conv 512 4 4 512 4 4 2359296.0 0.03 75,489,280.0 37,748,736.0 9469952.0 32768.0 0.00% 9502720.0
53 layer4.0.conv2.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
54 layer4.0.conv2.relu 512 4 4 512 4 4 0.0 0.03 0.0 0.0 0.0 0.0 0.00% 0.0
55 layer4.0.shortcut.conv_bn.conv 256 8 8 512 4 4 131072.0 0.03 4,186,112.0 2,097,152.0 589824.0 32768.0 0.00% 622592.0
56 layer4.0.shortcut.conv_bn.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
57 layer4.0.shortcut.conv_bn.relu 512 4 4 512 4 4 0.0 0.03 0.0 0.0 0.0 0.0 0.00% 0.0
58 layer4.1.conv1.conv 512 4 4 512 4 4 2359296.0 0.03 75,489,280.0 37,748,736.0 9469952.0 32768.0 0.00% 9502720.0
59 layer4.1.conv1.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
60 layer4.1.conv1.relu 512 4 4 512 4 4 0.0 0.03 8,192.0 8,192.0 32768.0 32768.0 0.00% 65536.0
61 layer4.1.conv2.conv 512 4 4 512 4 4 2359296.0 0.03 75,489,280.0 37,748,736.0 9469952.0 32768.0 0.00% 9502720.0
62 layer4.1.conv2.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
63 layer4.1.conv2.relu 512 4 4 512 4 4 0.0 0.03 0.0 0.0 0.0 0.0 0.00% 0.0
64 layer4.1.shortcut 512 4 4 512 4 4 0.0 0.03 0.0 0.0 0.0 0.0 0.00% 0.0
65 linear 512 10 5130.0 0.00 10,230.0 5,120.0 22568.0 40.0 0.00% 22608.0
total 11173962.0 7.75 1,112,999,926.0 556,962,816.0 22568.0 40.0 0.00% 57227600.0
=======================================================================================================================================================================
Total params: 11,173,962
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 7.75MB
Total MAdd: 1.11GMAdd
Total Flops: 556.96MFlops
Total MemR+W: 54.58MB
The authors' code for counting the nonzero convolutional-layer parameters is as follows:
# count only conv params for now
# note: `net` is expected to be a state_dict (the calls below pass model.state_dict())
def get_no_params(net, verbose=False, mask=False):
    params = net
    tot = 0
    for p in params:
        no = torch.sum(params[p] != 0)
        if "conv" in p:
            tot += no
    return tot
The result is:
tensor(11178452)
Although the authors only count convolutional-layer parameters, the result is almost identical to the torchstat total. This is because conv layers account for the vast majority of ResNet's parameters, as the per-layer counts in the table above show. At this point nothing has been pruned and no parameter is zero, so the nonzero conv-parameter count equals the total conv-parameter count.
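To see where those parameters live, the counter can be extended into a small per-layer report (a hypothetical helper of mine, operating on a plain state_dict just like get_no_params):
def layer_sparsity(state_dict):
    """Print nonzero / total counts for every conv weight tensor."""
    total, nonzero = 0, 0
    for name, tensor in state_dict.items():
        if "conv" in name and tensor.dim() == 4:   # conv weights only
            nz = int((tensor != 0).sum())
            total += tensor.numel()
            nonzero += nz
            print(f"{name:45s} {nz:>9d} / {tensor.numel():<9d}")
    print(f"conv total: {nonzero} / {total} ({100.0 * nonzero / total:.1f}% nonzero)")

layer_sparsity(model.state_dict())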
Part 4: Weight pruning
The authors' code implements the L1 strategy in the unstructured_prune function. It does not modify the model's state_dict directly; instead it produces a mask, i.e. after this function runs, only the model's mask has been updated. Reading the source shows that the weights are only actually changed during the forward pass (a sketch of that masked convolution follows the pruning code below).
def unstructured_prune(self, model, prune_rate=50.0):
    # get all the prunable convolutions
    convs = model.get_prunable_layers(pruning_type=self.pruning_type)

    # collate all weights into a single vector so the l1-threshold can be calculated
    all_weights = torch.Tensor()
    if torch.cuda.is_available():
        all_weights = all_weights.cuda()
    for conv in convs:
        # concatenate the weights of every prunable layer
        all_weights = torch.cat((all_weights.view(-1), conv.conv.weight.view(-1)))
    abs_weights = torch.abs(all_weights.detach())

    # the prune_rate-th percentile of |w| becomes the global threshold
    threshold = np.percentile(abs_weights.cpu().numpy(), prune_rate)

    # prune anything beneath the l1-threshold
    for conv in model.get_prunable_layers(pruning_type=self.pruning_type):
        conv.mask.update(
            # element-wise product: new binary mask * existing mask
            torch.mul(
                torch.gt(torch.abs(conv.conv.weight),
                         torch.as_tensor(np.array(threshold).astype('float')).cuda()).float(),
                conv.mask.mask.weight.cuda(),
            )
        )
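The masked convolution itself is not shown above; a minimal sketch of the idea (my own, not the repo's exact class): the mask is stored as a buffer and multiplies the weight inside forward, so the stored weights remain untouched until inference.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    """Conv layer whose effective weight is weight * mask (sketch only)."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer('mask', torch.ones_like(self.weight))

    def forward(self, x):
        # pruning only updates self.mask; the zeroing happens here, at forward time
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)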
After pruning comes the fine-tuning (finetune) step: every time the model is pruned, it is trained for another 100 batches. During each fine-tuning pass the parameters that were zeroed receive gradient updates again, so after the update they are generally no longer zero.
def finetune(model, trainloader, criterion, optimizer, steps=100):
    # switch to train mode
    model.train()
    dataiter = iter(trainloader)
    for i in range(steps):
        try:
            input, target = next(dataiter)
        except StopIteration:
            dataiter = iter(trainloader)
            input, target = next(dataiter)
        input, target = input.to(device), target.to(device)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
To check how the number of zero-valued weights changes around fine-tuning:
pruner.prune(model, prune_rate)
print("before finetune=================={}".format(get_no_params(model.state_dict())))
finetune(model, trainloader, criterion, optimizer, args.finetune_steps)
print("after finetune=================={}".format(get_no_params(model.state_dict())))
validate(model, prune_rate, testloader, criterion, optimizer)
print("after validate=================={}".format(get_no_params(model.state_dict())))
The output:
before finetune==================11178452
after finetune==================11178452
after validate==================10062528
The experiment above shows that after fine-tuning all the zeroed parameters are trained again and become nonzero, which makes the model more robust; the actual pruning effect (weights becoming zero) only appears when the mask is applied during forward inference at test time.
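If you want the zeros to be materialised in the checkpoint itself (so that get_no_params reports them without running a forward pass), the mask can be folded into the weights explicitly. A sketch, assuming the conv.conv.weight / conv.mask.mask.weight attributes seen in unstructured_prune above:
import torch

@torch.no_grad()
def apply_masks(model, pruning_type):
    # fold every conv mask into its weight so the zeros survive in state_dict
    for conv in model.get_prunable_layers(pruning_type=pruning_type):
        mask = conv.mask.mask.weight.to(conv.conv.weight.device)
        conv.conv.weight.mul_(mask)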
Increasing the prune rate step by step gives the results below; the model still performs well at an 85% prune rate.
prune rate:0.0, no zero param:11178451, err:4.16
prune rate:5.0, no zero param:10620490, err:4.64
prune rate:10.0, no zero param:10062528, err:4.66
prune rate:15.0, no zero param:9504567, err:4.69
prune rate:20.0, no zero param:8946605, err:4.5
prune rate:25.0, no zero param:8388644, err:4.52
prune rate:30.0, no zero param:7830680, err:4.53
prune rate:35.0, no zero param:7272721, err:4.53
prune rate:40.0, no zero param:6714759, err:4.51
prune rate:45.0, no zero param:6156797, err:4.44
prune rate:50.0, no zero param:5598836, err:4.52
prune rate:55.0, no zero param:5040874, err:4.5
prune rate:60.0, no zero param:4482913, err:4.68
prune rate:65.0, no zero param:3924949, err:4.67
prune rate:70.0, no zero param:3366990, err:4.64
prune rate:75.0, no zero param:2809028, err:4.67
prune rate:80.0, no zero param:2251067, err:5.01
prune rate:85.0, no zero param:1693105, err:5.46
prune rate:90.0, no zero param:1135143, err:8.22
prune rate:95.0, no zero param:577182, err:31.93
Now the key experiment: run torchstat on the model pruned at an 85% rate. The output below shows that the parameter count and FLOPs do not change at all, even though only about 1.7M weights are nonzero, roughly one tenth of the original 11.17M.
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 conv_bn_relu.conv 3 32 32 64 32 32 1728.0 0.25 3,473,408.0 1,769,472.0 19200.0 262144.0 0.00% 281344.0
1 conv_bn_relu.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
2 conv_bn_relu.relu 64 32 32 64 32 32 0.0 0.25 65,536.0 65,536.0 262144.0 262144.0 0.00% 524288.0
3 layer1.0.conv1.conv 64 32 32 64 32 32 36864.0 0.25 75,431,936.0 37,748,736.0 409600.0 262144.0 89.46% 671744.0
4 layer1.0.conv1.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
5 layer1.0.conv1.relu 64 32 32 64 32 32 0.0 0.25 65,536.0 65,536.0 262144.0 262144.0 0.00% 524288.0
6 layer1.0.conv2.conv 64 32 32 64 32 32 36864.0 0.25 75,431,936.0 37,748,736.0 409600.0 262144.0 0.00% 671744.0
7 layer1.0.conv2.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
8 layer1.0.conv2.relu 64 32 32 64 32 32 0.0 0.25 0.0 0.0 0.0 0.0 0.00% 0.0
9 layer1.0.shortcut 64 32 32 64 32 32 0.0 0.25 0.0 0.0 0.0 0.0 0.00% 0.0
10 layer1.1.conv1.conv 64 32 32 64 32 32 36864.0 0.25 75,431,936.0 37,748,736.0 409600.0 262144.0 0.00% 671744.0
11 layer1.1.conv1.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
12 layer1.1.conv1.relu 64 32 32 64 32 32 0.0 0.25 65,536.0 65,536.0 262144.0 262144.0 0.00% 524288.0
13 layer1.1.conv2.conv 64 32 32 64 32 32 36864.0 0.25 75,431,936.0 37,748,736.0 409600.0 262144.0 0.00% 671744.0
14 layer1.1.conv2.bn 64 32 32 64 32 32 128.0 0.25 262,144.0 131,072.0 262656.0 262144.0 0.00% 524800.0
15 layer1.1.conv2.relu 64 32 32 64 32 32 0.0 0.25 0.0 0.0 0.0 0.0 0.00% 0.0
16 layer1.1.shortcut 64 32 32 64 32 32 0.0 0.25 0.0 0.0 0.0 0.0 0.00% 0.0
17 layer2.0.conv1.conv 64 32 32 128 16 16 73728.0 0.12 37,715,968.0 18,874,368.0 557056.0 131072.0 0.00% 688128.0
18 layer2.0.conv1.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
19 layer2.0.conv1.relu 128 16 16 128 16 16 0.0 0.12 32,768.0 32,768.0 131072.0 131072.0 0.00% 262144.0
20 layer2.0.conv2.conv 128 16 16 128 16 16 147456.0 0.12 75,464,704.0 37,748,736.0 720896.0 131072.0 0.00% 851968.0
21 layer2.0.conv2.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
22 layer2.0.conv2.relu 128 16 16 128 16 16 0.0 0.12 0.0 0.0 0.0 0.0 0.00% 0.0
23 layer2.0.shortcut.conv_bn.conv 64 32 32 128 16 16 8192.0 0.12 4,161,536.0 2,097,152.0 294912.0 131072.0 0.00% 425984.0
24 layer2.0.shortcut.conv_bn.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
25 layer2.0.shortcut.conv_bn.relu 128 16 16 128 16 16 0.0 0.12 0.0 0.0 0.0 0.0 0.00% 0.0
26 layer2.1.conv1.conv 128 16 16 128 16 16 147456.0 0.12 75,464,704.0 37,748,736.0 720896.0 131072.0 9.69% 851968.0
27 layer2.1.conv1.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
28 layer2.1.conv1.relu 128 16 16 128 16 16 0.0 0.12 32,768.0 32,768.0 131072.0 131072.0 0.00% 262144.0
29 layer2.1.conv2.conv 128 16 16 128 16 16 147456.0 0.12 75,464,704.0 37,748,736.0 720896.0 131072.0 0.00% 851968.0
30 layer2.1.conv2.bn 128 16 16 128 16 16 256.0 0.12 131,072.0 65,536.0 132096.0 131072.0 0.00% 263168.0
31 layer2.1.conv2.relu 128 16 16 128 16 16 0.0 0.12 0.0 0.0 0.0 0.0 0.00% 0.0
32 layer2.1.shortcut 128 16 16 128 16 16 0.0 0.12 0.0 0.0 0.0 0.0 0.00% 0.0
33 layer3.0.conv1.conv 128 16 16 256 8 8 294912.0 0.06 37,732,352.0 18,874,368.0 1310720.0 65536.0 0.00% 1376256.0
34 layer3.0.conv1.bn 256 8 8 256 8 8 512.0 0.06 65,536.0 32,768.0 67584.0 65536.0 0.00% 133120.0
47 layer3.1.conv2.relu 256 8 8 256 8 8 0.0 0.06 0.0 0.0 0.0 0.0 0.00% 0.0
48 layer3.1.shortcut 256 8 8 256 8 8 0.0 0.06 0.0 0.0 0.0 0.0 0.00% 0.0
49 layer4.0.conv1.conv 256 8 8 512 4 4 1179648.0 0.03 37,740,544.0 18,874,368.0 4784128.0 32768.0 0.00% 4816896.0
50 layer4.0.conv1.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
51 layer4.0.conv1.relu 512 4 4 512 4 4 0.0 0.03 8,192.0 8,192.0 32768.0 32768.0 0.00% 65536.0
52 layer4.0.conv2.conv 512 4 4 512 4 4 2359296.0 0.03 75,489,280.0 37,748,736.0 9469952.0 32768.0 0.00% 9502720.0
53 layer4.0.conv2.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.84% 69632.0
54 layer4.0.conv2.relu 512 4 4 512 4 4 0.0 0.03 0.0 0.0 0.0 0.0 0.00% 0.0
55 layer4.0.shortcut.conv_bn.conv 256 8 8 512 4 4 131072.0 0.03 4,186,112.0 2,097,152.0 589824.0 32768.0 0.00% 622592.0
56 layer4.0.shortcut.conv_bn.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
57 layer4.0.shortcut.conv_bn.relu 512 4 4 512 4 4 0.0 0.03 0.0 0.0 0.0 0.0 0.00% 0.0
58 layer4.1.conv1.conv 512 4 4 512 4 4 2359296.0 0.03 75,489,280.0 37,748,736.0 9469952.0 32768.0 0.00% 9502720.0
59 layer4.1.conv1.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
60 layer4.1.conv1.relu 512 4 4 512 4 4 0.0 0.03 8,192.0 8,192.0 32768.0 32768.0 0.00% 65536.0
61 layer4.1.conv2.conv 512 4 4 512 4 4 2359296.0 0.03 75,489,280.0 37,748,736.0 9469952.0 32768.0 0.00% 9502720.0
62 layer4.1.conv2.bn 512 4 4 512 4 4 1024.0 0.03 32,768.0 16,384.0 36864.0 32768.0 0.00% 69632.0
63 layer4.1.conv2.relu 512 4 4 512 4 4 0.0 0.03 0.0 0.0 0.0 0.0 0.00% 0.0
64 layer4.1.shortcut 512 4 4 512 4 4 0.0 0.03 0.0 0.0 0.0 0.0 0.00% 0.0
65 linear 512 10 5130.0 0.00 10,230.0 5,120.0 22568.0 40.0 0.00% 22608.0
total 11173962.0 7.75 1,112,999,926.0 556,962,816.0 22568.0 40.0 100.00% 57227600.0
=======================================================================================================================================================================
Total params: 11,173,962
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 7.75MB
Total MAdd: 1.11GMAdd
Total Flops: 556.96MFlops
Total MemR+W: 54.58MB
Summary
Reproducing this paper drove home the frustrating side of unstructured pruning: the number of zeroed parameters has no positive effect on the model in practice. This is also why several papers point out that sparse matrices do not yield real speedups: a sparse weight matrix cannot take advantage of dense BLAS libraries. To actually accelerate sparse matrices you need matching hardware and software/driver support, which is a large engineering effort that I will not pursue further here.
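A rough way to see the storage-versus-compute split (my own check, not from the paper): converting an 85%-sparse matrix to CSR shrinks storage, but a dense matmul on it still costs the same FLOPs unless sparse kernels and hardware support are available.
import torch

w = torch.randn(512, 512)
thr = w.abs().flatten().quantile(0.85)
w[w.abs() < thr] = 0                       # ~85% of the entries are now zero

dense_bytes = w.numel() * w.element_size()

csr = w.to_sparse_csr()                    # storage does shrink with sparsity ...
sparse_bytes = (csr.values().numel() * csr.values().element_size()
                + csr.col_indices().numel() * csr.col_indices().element_size()
                + csr.crow_indices().numel() * csr.crow_indices().element_size())
print(f"dense: {dense_bytes} B, CSR: {sparse_bytes} B")
# ... but a dense matmul with `w` still performs the same FLOPs;
# real speedups require sparse kernels / hardware support.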
That said, the paper's three-step train-prune-retrain idea is a true classic and also plays an important role in structured pruning; I will follow up with reproductions of structured-pruning papers.