import os
import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from vgg import vgg
import numpy as np
# The purpose of these packages is explained in the main (training) script.
# Prune settings: command-line options for the pruning script.
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
# NOTE(review): --dataset is parsed but the visible code always loads CIFAR-10.
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='training dataset (default: cifar10)')
# The help text previously advertised 1000 while the real default is 1200.
parser.add_argument('--test-batch-size', type=int, default=1200, metavar='N',
                    help='input batch size for testing (default: 1200)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
# Fraction of all BatchNorm channels, ranked by |scale factor|, that falls
# below the pruning threshold.
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
# Checkpoint of the trained, unpruned model; when empty nothing is loaded.
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to raw trained model (default: none)')
# Destination file for the pruned model's cfg and state_dict.
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save prune model (default: none)')
args = parser.parse_args()
# Use the GPU only when it is both allowed and actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
# Build the unpruned network and move it to the GPU when CUDA is enabled.
model = vgg()
if args.cuda:
    model.cuda()

# Restore the trained weights when a checkpoint path was supplied.
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        # Remap GPU-saved tensors to the CPU when CUDA is not in use;
        # without this, loading fails on machines without a GPU.
        checkpoint = torch.load(
            args.model, map_location=None if args.cuda else 'cpu')
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        # The parser defines --model, not --resume: reading args.resume here
        # raised AttributeError instead of reporting the missing file.
        print("=> no checkpoint found at '{}'".format(args.model))
print(model)
# Collect every BatchNorm2d layer once; model.modules() walks all sub-layers.
bn_layers = [layer for layer in model.modules()
             if isinstance(layer, nn.BatchNorm2d)]

# Total number of BatchNorm channels across the whole network.
total = sum(layer.weight.data.shape[0] for layer in bn_layers)

# Gather the absolute scale factors of all channels into one flat tensor.
bn = torch.zeros(total)
index = 0
for layer in bn_layers:
    size = layer.weight.data.shape[0]
    bn[index:index + size] = layer.weight.data.abs().clone()
    index += size

# The global threshold is the value found at the requested position of the
# ascending-sorted scale factors.
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]
# ------------------------------ pre-pruning ------------------------------
# Zero the scale and shift of every BatchNorm channel whose |scale| is not
# above the global threshold, and record the resulting layout:
#   cfg      - kept-channel count per BatchNorm layer, 'M' per max-pool layer
#   cfg_mask - one 0/1 mask per BatchNorm layer (1 = channel is kept)
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        # 1.0 for channels to keep, 0.0 for channels to prune.
        mask = weight_copy.abs().gt(thre).float()
        # Move the mask to the GPU only when the model lives there; the
        # previous unconditional .cuda() broke runs with CUDA disabled.
        if args.cuda:
            mask = mask.cuda()
        remaining = int(torch.sum(mask))
        # Running count of pruned channels.
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(remaining)
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
              format(k, mask.shape[0], remaining))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total
print('Pre-processing Successful!')
# Simple test of the model after pre-processing pruning (the pruned BN scales are simply set to zero)
# ************************** test the model after pre-pruning **************************
def test():
    """Evaluate the global ``model`` on the CIFAR-10 test split.

    Prints the number of correct predictions together with the accuracy in
    percent.

    Returns:
        The accuracy as a Python float in the range [0, 1].
    """
    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
    # Load the test set (expected to be present under ./data).
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
    model.eval()  # switch to evaluation mode
    correct = 0
    # Variable(..., volatile=True) no longer disables autograd; no_grad() is
    # the supported way to skip graph construction during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # Index of the highest-scoring class for every sample.
            pred = output.max(1, keepdim=True)[1]
            # Accumulate as a Python int so the arithmetic below is ordinary
            # float division rather than integer-tensor arithmetic.
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, num_samples, 100. * correct / num_samples))
    return correct / float(num_samples)
# Evaluate the network whose pruned BatchNorm channels have been zeroed.
test()

# Make real prune
print(cfg)
# New network with the same layer layout but the reduced channel counts.
newmodel = vgg(cfg=cfg)
# Honour --no-cuda / missing GPU instead of moving to CUDA unconditionally.
if args.cuda:
    newmodel.cuda()

layer_id_in_cfg = 0
# start_mask marks the kept input channels of the layer being copied and
# end_mask its kept output channels. The first layer keeps all 3 input
# channels (presumably the RGB planes of the CIFAR images).
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
def kept_channel_indices(mask):
    """Return the positions of the non-zero entries of a 1-D channel mask.

    Args:
        mask: 1-D torch tensor whose non-zero entries mark kept channels.

    Returns:
        A 1-D numpy integer array. Unlike np.squeeze(np.argwhere(...)), the
        result stays 1-D when exactly one channel is kept, so ``.shape[0]``
        and tensor indexing keep working.
    """
    return np.flatnonzero(mask.cpu().numpy())


if __name__ == '__main__':
    # Walk both networks in lockstep and copy the surviving weights.
    for m0, m1 in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            idx1 = kept_channel_indices(end_mask)
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            m1.running_mean = m0.running_mean[idx1].clone()
            m1.running_var = m0.running_var[idx1].clone()
            layer_id_in_cfg += 1
            # The output mask of this layer becomes the input mask of the
            # next convolution.
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
        elif isinstance(m0, nn.Conv2d):
            idx0 = kept_channel_indices(start_mask)
            idx1 = kept_channel_indices(end_mask)
            print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
            # Conv weights are [out_channels, in_channels, kH, kW]: select
            # the kept input channels first, then the kept output channels.
            w = m0.weight.data[:, idx0, :, :].clone()
            w = w[idx1, :, :, :].clone()
            m1.weight.data = w.clone()
            # Carry the bias over when the convolution has one.
            if m0.bias is not None:
                m1.bias.data = m0.bias.data[idx1].clone()
        elif isinstance(m0, nn.Linear):
            # Only the input features shrink; the output units are unchanged.
            idx0 = kept_channel_indices(start_mask)
            m1.weight.data = m0.weight.data[:, idx0].clone()
            # Without this copy the new layer keeps its random initial bias.
            if m0.bias is not None:
                m1.bias.data = m0.bias.data.clone()

    # --save defaults to '', which torch.save cannot write to.
    if args.save:
        torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, args.save)
    else:
        print('=> no --save path given, pruned model not written to disk')
    print(newmodel)
    # Rebind the global so that test() evaluates the pruned network.
    model = newmodel
# NOTE: trailing blog-page residue, not part of the script
# ("Untitled"; "Latest recommended article published 2023-06-09 14:28:43").