import os
import datetime
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor
import copy
from collections import OrderedDict
import transforms
from network_files import FasterRCNN, AnchorsGenerator
from my_dataset import VOCDataSet
from train_utils import GroupedBatchSampler, create_aspect_ratio_groups
from train_utils import train_eval_utils as utils
# ---------------------------- ECA attention module ----------------------------
class ECAAttention(nn.Module):
    """Efficient Channel Attention: re-weights channels via a 1-D convolution
    over globally pooled channel descriptors (no dimensionality reduction)."""

    def __init__(self, channels, kernel_size=3):
        super(ECAAttention, self).__init__()
        # Squeeze each channel's spatial extent to a single descriptor.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 1-D conv across the channel axis captures local cross-channel interaction.
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size,
                              padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` ([B, C, H, W]) with its channels re-weighted."""
        batch, n_channels = x.size(0), x.size(1)
        # Global average pooling -> [B, C, 1].
        descriptor = self.avg_pool(x).view(batch, n_channels, 1)
        # Treat the channel dimension as a 1-D sequence for the conv, then restore.
        weights = self.conv(descriptor.transpose(1, 2)).transpose(1, 2)
        # Sigmoid gate in (0, 1) per channel.
        weights = self.sigmoid(weights.view(batch, n_channels, 1, 1))
        return x * weights.expand_as(x)
# ---------------------------- Feature fusion module ----------------------------
class FeatureFusionModule(nn.Module):
    """Fuses a student feature map with a (projected) teacher feature map using
    learned per-pixel soft weights, then refines the result with ECA attention."""

    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 conv aligns the teacher's channel count with the student's.
        self.teacher_proj = nn.Conv2d(teacher_channels, student_channels, kernel_size=1)
        # Predicts a 2-way softmax weight map (student vs. teacher) per pixel.
        self.attention = nn.Sequential(
            nn.Conv2d(student_channels * 2, student_channels // 8, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(student_channels // 8, 2, kernel_size=1),
            nn.Softmax(dim=1)
        )
        # Channel re-weighting applied after fusion.
        self.eca = ECAAttention(student_channels, kernel_size=3)

    def forward(self, student_feat, teacher_feat):
        """Return the fused feature map, shaped like ``student_feat``."""
        # Match the teacher's spatial size to the student's when they differ.
        if teacher_feat.shape[2:] != student_feat.shape[2:]:
            teacher_feat = F.interpolate(teacher_feat, size=student_feat.shape[2:],
                                         mode='bilinear', align_corners=False)
        # Project teacher channels, concatenate, and predict fusion weights.
        aligned_teacher = self.teacher_proj(teacher_feat)
        weights = self.attention(torch.cat([student_feat, aligned_teacher], dim=1))
        # Convex per-pixel combination of the two sources.
        fused = weights[:, 0:1] * student_feat + weights[:, 1:2] * aligned_teacher
        return self.eca(fused)
# ---------------------------- Simplified FPN ----------------------------
class SimpleFPN(nn.Module):
    """Minimal top-down Feature Pyramid Network with ECA attention on every output.

    Consumes backbone features 'c2'..'c5' and produces 'p2'..'p5' at the same
    spatial resolutions, each with ``out_channels`` channels."""

    def __init__(self, in_channels_list, out_channels):
        super().__init__()
        self.inner_blocks = nn.ModuleList()   # 1x1 lateral convs
        self.layer_blocks = nn.ModuleList()   # 3x3 output convs
        self.eca_blocks = nn.ModuleList()     # per-output ECA attention
        # One (lateral conv, output conv, ECA) triple per pyramid level.
        for in_channels in in_channels_list:
            self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, 1))
            self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, 3, padding=1))
            self.eca_blocks.append(ECAAttention(out_channels, kernel_size=3))
        # Kaiming init for every conv, zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """``x`` is a dict holding the 'c2', 'c3', 'c4' and 'c5' feature maps."""
        # Start from the coarsest level: c5 -> p5.
        top = self.inner_blocks[3](x['c5'])
        outputs = [self.eca_blocks[3](self.layer_blocks[3](top))]
        # Top-down pathway: upsample, add lateral, then output conv + ECA.
        for idx in (2, 1, 0):
            lateral = self.inner_blocks[idx](x[f'c{idx + 2}'])
            top = F.interpolate(top, size=lateral.shape[-2:], mode="nearest") + lateral
            outputs.insert(0, self.eca_blocks[idx](self.layer_blocks[idx](top)))
        # Ordered fine-to-coarse feature dict.
        return {
            'p2': outputs[0],
            'p3': outputs[1],
            'p4': outputs[2],
            'p5': outputs[3]
        }
# ---------------------------- Channel pruner ----------------------------
class ChannelPruner:
    """Activation-based structured channel pruner for the student detector.

    The previous implementation never pruned anything and could produce
    shape-broken models, for three reasons:
      1. ``prune_channels`` compared layer-name *prefixes* (e.g.
         ``"backbone.backbone.layer2"``) against full module names
         (``"backbone.backbone.layer2.0.conv1"``) with exact key membership,
         so no ratio ever matched a layer and every mask stayed all-True.
      2. ``_forward_once`` wrapped a single 4-D batch tensor in a list, while
         torchvision-style detection models expect a list of 3-D image tensors.
      3. ``apply_pruning`` sliced conv output channels without consistently
         slicing the paired BatchNorm and the next conv's input channels,
         breaking residual additions, FPN lateral sums and fixed-width heads.

    This version physically prunes only channels that are structurally safe to
    remove: the *first* conv of each backbone residual block, whose output is
    consumed exclusively by the block-local ``bn1``/``conv2`` pair.  All three
    modules are sliced with the same index set, so the pruned model keeps
    valid shapes and is immediately runnable and fine-tunable.  FPN convs,
    ``teacher_proj`` and the second conv of each block are left at full width
    because their outputs feed residual adds, lateral sums or fixed-width
    heads (changing them would require rewiring many consumers).
    """

    def __init__(self, model, input_size, device='cpu'):
        """
        Args:
            model: trained detection model (Faster R-CNN student).
            input_size: (B, C, H, W) used to sample random calibration inputs.
            device: device for importance computation and pruning.
        """
        self.model = model
        self.input_size = input_size  # (B, C, H, W)
        self.device = device
        self.model.to(device)
        self.prunable_layers = self._identify_prunable_layers()
        self.channel_importance = {}  # layer name -> per-channel score tensor
        self.mask = {}                # layer name -> bool keep-mask (True = keep)
        self.reset()

    def reset(self):
        """Reset pruning state: fresh model copy and all-channels-kept masks."""
        self.pruned_model = copy.deepcopy(self.model)
        self.pruned_model.to(self.device)
        for name, module in self.prunable_layers.items():
            self.mask[name] = torch.ones(module.out_channels, dtype=torch.bool,
                                         device=self.device)

    def _identify_prunable_layers(self):
        """Return conv layers whose output channels can be removed safely.

        A conv qualifies when it is the first conv of a backbone residual
        block (qualified name ``...layerN.M.conv1``) and its sibling ``bn1``
        and ``conv2`` modules exist, so all three can be sliced consistently
        in :meth:`apply_pruning`.
        """
        modules = dict(self.model.named_modules())
        prunable = OrderedDict()
        for name, module in self.model.named_modules():
            if not isinstance(module, nn.Conv2d) or module.out_channels <= 1:
                continue
            # Only the first conv inside backbone residual blocks is safe.
            if "backbone.backbone." not in name or not name.endswith(".conv1"):
                continue
            block_prefix = name[:-len(".conv1")]
            bn = modules.get(block_prefix + ".bn1")
            next_conv = modules.get(block_prefix + ".conv2")
            if isinstance(bn, nn.BatchNorm2d) and isinstance(next_conv, nn.Conv2d):
                prunable[name] = module
        return prunable

    def compute_channel_importance(self, dataloader=None, num_batches=10):
        """Accumulate per-channel importance (mean |activation|) via forward hooks.

        Args:
            dataloader: optional loader yielding (images, targets); when None,
                random images shaped from ``input_size`` are used instead.
            num_batches: number of batches to accumulate over.
        """
        self.pruned_model.eval()
        for name, module in self.prunable_layers.items():
            self.channel_importance[name] = torch.zeros(module.out_channels,
                                                        device=self.device)
        with torch.no_grad():
            if dataloader is None:
                batch, channels, height, width = self.input_size
                for _ in range(num_batches):
                    # Detection models take a *list* of 3-D image tensors.
                    images = [torch.randn(channels, height, width, device=self.device)
                              for _ in range(batch)]
                    self._forward_once(images)
            else:
                for batch_idx, (images, _) in enumerate(dataloader):
                    if batch_idx >= num_batches:
                        break
                    self._forward_once([img.to(self.device) for img in images])

    def _forward_once(self, images):
        """Run one forward pass; hooks add each layer's mean |activation| per channel."""
        def make_hook(layer_name):
            # Bind the layer name per hook (avoids the late-binding lambda bug).
            def hook(module, inputs, output):
                self.channel_importance[layer_name] += torch.mean(
                    torch.abs(output), dim=(0, 2, 3))
            return hook

        handles = []
        for name, module in self.pruned_model.named_modules():
            if name in self.prunable_layers:
                handles.append(module.register_forward_hook(make_hook(name)))
        try:
            self.pruned_model(images)
        finally:
            # Always detach hooks, even if the forward pass raises.
            for handle in handles:
                handle.remove()

    def prune_channels(self, layer_prune_ratios):
        """Mark the least-important channels for removal, layer by layer.

        Args:
            layer_prune_ratios: mapping of layer-name prefix/substring (e.g.
                ``"backbone.backbone.layer2"``) to pruning ratio in [0, 1).
                Each ratio applies to every prunable conv whose qualified name
                contains that key; at least one channel is always kept.
        """
        for prefix, ratio in layer_prune_ratios.items():
            # Substring match: a prefix selects every prunable conv under it.
            matched = [name for name in self.prunable_layers if prefix in name]
            if not matched:
                print(f"警告: 前缀 {prefix} 未匹配到任何可剪枝层,跳过")
                continue
            for name in matched:
                importance = self.channel_importance.get(name)
                if importance is None:
                    continue  # importance was never computed for this layer
                num_channels = importance.numel()
                # Never remove every channel of a layer.
                num_prune = min(int(num_channels * ratio), num_channels - 1)
                if num_prune <= 0:
                    continue
                order = torch.argsort(importance)  # ascending: least important first
                self.mask[name][order[:num_prune]] = False

    def apply_pruning(self):
        """Physically remove the masked channels and return the slimmed model.

        For every pruned ``conv1`` the matching ``bn1`` (affine parameters and
        running statistics) and the input channels of the following ``conv2``
        are sliced with the same surviving-channel index set, so all tensor
        shapes in the network stay valid.
        """
        pruned_model = copy.deepcopy(self.model)
        pruned_model.to(self.device)
        modules = dict(pruned_model.named_modules())
        for name, keep in self.mask.items():
            if bool(keep.all()):
                continue  # nothing marked in this layer
            keep_idx = torch.nonzero(keep).flatten()
            block_prefix = name[:-len(".conv1")]
            conv1 = modules[name]
            bn1 = modules[block_prefix + ".bn1"]
            conv2 = modules[block_prefix + ".conv2"]
            # Slice conv1 output channels.
            conv1.weight = nn.Parameter(conv1.weight.data[keep_idx].clone())
            if conv1.bias is not None:
                conv1.bias = nn.Parameter(conv1.bias.data[keep_idx].clone())
            conv1.out_channels = keep_idx.numel()
            # Slice the paired BatchNorm to the surviving channels.
            bn1.weight = nn.Parameter(bn1.weight.data[keep_idx].clone())
            bn1.bias = nn.Parameter(bn1.bias.data[keep_idx].clone())
            bn1.running_mean = bn1.running_mean[keep_idx].clone()
            bn1.running_var = bn1.running_var[keep_idx].clone()
            bn1.num_features = keep_idx.numel()
            # Slice conv2 input channels to match.
            conv2.weight = nn.Parameter(conv2.weight.data[:, keep_idx].clone())
            conv2.in_channels = keep_idx.numel()
            print(f"剪枝层: {name}, 原始输出通道: {keep.numel()}, 剪枝后: {keep_idx.numel()}")
        self.pruned_model = pruned_model
        return pruned_model

    def evaluate_model(self, dataloader, model=None):
        """Evaluate a model (default: the pruned one) and return its mAP."""
        if model is None:
            model = self.pruned_model
        model.eval()
        coco_info = utils.evaluate(model, dataloader, device=self.device)
        return coco_info[1]  # mAP at IoU=0.50

    def get_model_size(self, model=None):
        """Return the serialized state-dict size of a model in MB."""
        if model is None:
            model = self.pruned_model
        import tempfile
        # Unique temp file avoids clobbering "temp.pth" in the working dir.
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp:
            path = tmp.name
        try:
            torch.save(model.state_dict(), path)
            return os.path.getsize(path) / (1024 * 1024)
        finally:
            os.remove(path)
def create_model(num_classes):
    """Build the student Faster R-CNN: ResNet-18 trunk + SimpleFPN + fusion heads."""

    # ---------------------------- student model definition ----------------------------
    class BackboneWithFPN(nn.Module):
        """Couples the ResNet feature extractor with the simplified FPN."""

        def __init__(self, backbone, fpn):
            super().__init__()
            self.backbone = backbone
            self.fpn = fpn
            self.out_channels = 256  # channel width of every FPN output

        def forward(self, x):
            return self.fpn(self.backbone(x))

    # Load an ImageNet-pretrained ResNet-18, supporting both torchvision APIs.
    try:
        trunk = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    except AttributeError:
        # Older torchvision releases use the ``pretrained`` flag.
        trunk = torchvision.models.resnet18(pretrained=True)

    # Tap the four residual stages as feature-fusion / FPN inputs.
    return_nodes = {
        "layer1": "c2",  # FPN P2 (stride 1/4)
        "layer2": "c3",  # FPN P3 (stride 1/8)
        "layer3": "c4",  # FPN P4 (stride 1/16)
        "layer4": "c5",  # FPN P5 (stride 1/32)
    }
    trunk = create_feature_extractor(trunk, return_nodes=return_nodes)

    # Wrap the trunk and the simplified FPN into the detector backbone.
    backbone_with_fpn = BackboneWithFPN(trunk, SimpleFPN([64, 128, 256, 512], 256))

    # Dense anchor coverage: five levels, three sizes and five ratios each.
    anchor_sizes = ((16, 32, 48), (32, 64, 96), (64, 128, 192),
                    (128, 256, 384), (256, 512, 768))
    aspect_ratios = ((0.33, 0.5, 1.0, 2.0, 3.0),) * len(anchor_sizes)
    anchor_generator = AnchorsGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)

    roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['p2', 'p3', 'p4', 'p5'],
        output_size=[7, 7],
        sampling_ratio=2,
    )

    model = FasterRCNN(
        backbone=backbone_with_fpn,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
    )

    # Per-level distillation fusion modules (teacher and student both 256-ch).
    model.feature_fusion = nn.ModuleDict({
        'p2': FeatureFusionModule(256, 256),
        'p3': FeatureFusionModule(256, 256),
        'p4': FeatureFusionModule(256, 256),
        'p5': FeatureFusionModule(256, 256),
    })
    return model
def main(args):
    """End-to-end pruning pipeline.

    Loads the trained student detector, estimates channel importance on the
    validation set, prunes channels layer by layer, reports size/mAP before
    and after, saves the pruned model, then fine-tunes it with SGD.
    """
    # Make sure the output directory exists before anything is saved.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    # ---------------------------- pruning workflow ----------------------------
    print("开始模型剪枝...")
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    # Rebuild the model (+1 for the background class) and load the checkpoint.
    model = create_model(num_classes=args.num_classes + 1)
    checkpoint = torch.load(args.resume, map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    # Input size (B, C, H, W) — should match the resolution used during training.
    input_size = (1, 3, 800, 600)  # (B, C, H, W)
    # Validation dataset/loader; the custom collate_fn keeps images as a list.
    val_dataset = VOCDataSet(args.data_path, "2012", transforms.Compose([transforms.ToTensor()]), "val.txt")
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=4,
        collate_fn=val_dataset.collate_fn
    )
    # Build the pruner on the loaded model.
    pruner = ChannelPruner(model, input_size, device=device)
    # Estimate per-channel importance from real validation images.
    print("计算通道重要性...")
    pruner.compute_channel_importance(dataloader=val_data_loader, num_batches=10)
    # Per-layer pruning ratios: middle layers pruned harder than low/high layers.
    layer_prune_ratios = {
        # backbone stages
        "backbone.backbone.layer1": args.layer1_ratio,  # low-level stage: prune lightly
        "backbone.backbone.layer2": args.layer2_ratio,  # mid-level stage: prune harder
        "backbone.backbone.layer3": args.layer3_ratio,  # mid-level stage: prune harder
        "backbone.backbone.layer4": args.layer4_ratio,  # high-level stage: prune lightly
        # FPN convs
        "fpn.inner_blocks": args.fpn_inner_ratio,  # FPN lateral convs
        "fpn.layer_blocks": args.fpn_layer_ratio,  # FPN output convs
        # FeatureFusionModule teacher projections
        "feature_fusion.p2.teacher_proj": args.ff_p2_ratio,
        "feature_fusion.p3.teacher_proj": args.ff_p3_ratio,
        "feature_fusion.p4.teacher_proj": args.ff_p4_ratio,
        "feature_fusion.p5.teacher_proj": args.ff_p5_ratio,
    }
    # Select the channels to drop.
    print("执行通道剪枝...")
    pruner.prune_channels(layer_prune_ratios)
    # Materialize the pruned model.
    pruned_model = pruner.apply_pruning()
    # Compare model size and mAP before/after pruning.
    original_size = pruner.get_model_size(model)
    pruned_size = pruner.get_model_size(pruned_model)
    original_map = pruner.evaluate_model(val_data_loader, model)
    pruned_map = pruner.evaluate_model(val_data_loader, pruned_model)
    print(f"原始模型大小: {original_size:.2f} MB")
    print(f"剪枝后模型大小: {pruned_size:.2f} MB")
    print(f"模型压缩率: {100 * (1 - pruned_size / original_size):.2f}%")
    print(f"原始mAP: {original_map:.4f}, 剪枝后mAP: {pruned_map:.4f}")
    # Persist the pruned weights.
    pruned_model_path = os.path.join(args.output_dir, f"pruned_resNetFpn_{args.num_classes}classes.pth")
    torch.save(pruned_model.state_dict(), pruned_model_path)
    print(f"剪枝后的模型已保存至: {pruned_model_path}")
    # ---------------------------- fine-tuning ----------------------------
    print("准备对剪枝后的模型进行微调...")
    # Training dataset with horizontal-flip augmentation.
    train_dataset = VOCDataSet(args.data_path, "2012",
                               transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.RandomHorizontalFlip(0.5)
                               ]), "train.txt")
    # Random sampling over the training set.
    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=4,
        collate_fn=train_dataset.collate_fn
    )
    # SGD over the trainable parameters of the pruned model.
    params = [p for p in pruned_model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params,
        lr=args.lr,
        momentum=0.9,
        weight_decay=0.0005
    )
    # Step decay of the learning rate.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.lr_step_size,
        gamma=args.lr_gamma
    )
    # Fine-tuning loop: train, step LR, evaluate, keep the best checkpoint.
    print(f"开始微调: 批次大小={args.batch_size}, 学习率={args.lr}, 轮数={args.epochs}")
    best_map = 0.0
    for epoch in range(args.epochs):
        # One training epoch.
        utils.train_one_epoch(pruned_model, optimizer, train_data_loader,
                              device, epoch, print_freq=50)
        # Learning-rate schedule step.
        lr_scheduler.step()
        # Evaluate on the validation set.
        coco_info = utils.evaluate(pruned_model, val_data_loader, device=device)
        # Track and save the best checkpoint by mAP@0.5.
        map_50 = coco_info[1]  # COCO metric: mAP at IoU=0.50
        if map_50 > best_map:
            best_map = map_50
            torch.save({
                'model': pruned_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, os.path.join(args.output_dir, f"finetuned_pruned_best.pth"))
        print(f"Epoch {epoch + 1}/{args.epochs}, mAP@0.5: {map_50:.4f}")
    print(f"微调完成,最佳mAP@0.5: {best_map:.4f}")
if __name__ == "__main__":
    # Note: stray forum prose that was fused onto the final line of the original
    # script (making it a syntax error) has been removed.
    parser = argparse.ArgumentParser(description=__doc__)
    # ---------------------------- pruning options ----------------------------
    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('--data-path', default='./', help='dataset')
    parser.add_argument('--num-classes', default=6, type=int, help='num_classes')
    parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
    parser.add_argument('--resume', default='./save_weights/resNetFpn-zuizhong.pth', type=str,
                        help='resume from checkpoint')
    # Layer-wise pruning ratios.
    parser.add_argument('--layer1-ratio', default=0.1, type=float, help='layer1 pruning ratio')
    parser.add_argument('--layer2-ratio', default=0.5, type=float, help='layer2 pruning ratio')
    parser.add_argument('--layer3-ratio', default=0.5, type=float, help='layer3 pruning ratio')
    parser.add_argument('--layer4-ratio', default=0.1, type=float, help='layer4 pruning ratio')
    parser.add_argument('--fpn-inner-ratio', default=0.2, type=float, help='FPN inner blocks pruning ratio')
    parser.add_argument('--fpn-layer-ratio', default=0.2, type=float, help='FPN layer blocks pruning ratio')
    parser.add_argument('--ff-p2-ratio', default=0.2, type=float, help='Feature fusion P2 pruning ratio')
    parser.add_argument('--ff-p3-ratio', default=0.2, type=float, help='Feature fusion P3 pruning ratio')
    parser.add_argument('--ff-p4-ratio', default=0.2, type=float, help='Feature fusion P4 pruning ratio')
    parser.add_argument('--ff-p5-ratio', default=0.2, type=float, help='Feature fusion P5 pruning ratio')
    # ---------------------------- fine-tuning options ----------------------------
    parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run')
    parser.add_argument('--batch-size', default=8, type=int, help='batch size')
    parser.add_argument('--lr', default=0.05, type=float, help='initial learning rate')
    parser.add_argument('--lr-step-size', default=3, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    args = parser.parse_args()
    main(args)