使用Truncated SVD优化ResNet50全连接层:Distiller项目实践
distiller 项目地址: https://gitcode.com/gh_mirrors/di/distiller
引言
在现代深度学习模型中,全连接层往往包含大量参数,成为模型压缩和加速的重要目标。本文将介绍如何利用Distiller项目中的技术,通过截断奇异值分解(Truncated SVD)方法来优化ResNet50模型的最后一层全连接层,在保持模型精度的同时显著减少参数数量。
Truncated SVD基础原理
奇异值分解(SVD)是线性代数中一种重要的矩阵分解方法,可以将任意矩阵W(m×n)分解为三个矩阵的乘积:
W = USVᵀ
其中:
- U是m×m的正交矩阵
- S是m×n的对角矩阵(奇异值按降序排列)
- Vᵀ是n×n的正交矩阵(V的转置)
Truncated SVD(截断SVD)是SVD的一种近似形式,它只保留前k个最大的奇异值(k<m),舍弃其余较小的奇异值。这种截断操作可以显著减少计算量和存储需求,同时保留矩阵的主要特征。
在神经网络中的应用
对于全连接层y = Wx + b,我们可以应用Truncated SVD来分解权重矩阵W:
- 首先对W进行SVD分解:W = USVᵀ
- 截断保留前k个奇异值,得到近似分解:W ≈ U' S' V'ᵀ
- 将分解后的矩阵重新组合为两个更小的矩阵:A = S'V'ᵀ (k×n) 和 U' (m×k)
这样,原始计算y = Wx + b可以近似为: y ≈ U'(Ax) + b
参数与计算量分析
原始全连接层:
- 参数数量:m×n
- 计算量(FLOPs):m×n
使用Truncated SVD后:
- 参数数量:m×k + k×n = k(m+n)
- 计算量:k(m+n)
压缩条件:当k < mn/(m+n)时,Truncated SVD版本参数更少
对于ResNet50最后一层(m=1000, n=2048):
- 平衡点k ≈ 672
- 当k<672时,Truncated SVD版本参数更少
实践步骤
1. 准备环境与数据
首先需要加载必要的库和ImageNet验证数据集:
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import scipy.stats as ss
import numpy as np
import matplotlib.pyplot as plt
# 数据加载函数
def imagenet_load_data(data_dir, batch_size, num_workers, shuffle=True):
test_dir = os.path.join(data_dir, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
test_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(test_dir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=batch_size, shuffle=shuffle,
num_workers=num_workers, pin_memory=True)
return test_loader
2. 实现Truncated SVD
def truncated_svd(W, l):
"""使用截断SVD压缩权重矩阵W"""
U, s, V = torch.svd(W, some=True)
Ul = U[:, :l]
sl = s[:l]
V = V.t()
Vl = V[:l, :]
SV = torch.mm(torch.diag(sl), Vl)
return Ul, SV
class TruncatedSVD(nn.Module):
def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):
super().__init__()
self.replaced_gemm = replaced_gemm
self.U, self.SV = truncated_svd(gemm_weights.data,
int(preserve_ratio * gemm_weights.size(0)))
self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()
self.fc_u.weight.data = self.U
self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()
self.fc_sv.weight.data = self.SV
def forward(self, x):
x = self.fc_sv.forward(x)
x = self.fc_u.forward(x)
return x
3. 替换ResNet50的全连接层
def replace(model):
fc_weights = model.state_dict()['fc.weight']
fc_layer = model.fc
model.fc = TruncatedSVD(fc_layer, fc_weights, 0.4) # 保留40%的奇异值
# 加载并修改ResNet50
resnet50 = models.create_model(pretrained=True, dataset='imagenet',
arch='resnet50', parallel=False)
resnet50 = deepcopy(resnet50)
replace(resnet50)
实验结果
在不同保留比例下的实验结果:
| 保留比例(k) | Top1准确率 | Top5准确率 | 参数数量 | 原始参数数量 | |------------|-----------|-----------|---------|------------| | 80% (800) | 76.02 | 92.86 | 2,438,400 | 2,048,000 | | 70% (700) | 76.03 | 92.85 | 2,133,600 | 2,048,000 | | 60% (600) | 75.98 | 92.82 | 1,828,800 | 2,048,000 | | 50% (500) | 75.78 | 92.77 | 1,524,000 | 2,048,000 | | 40% (400) | 75.65 | 92.75 | 1,219,200 | 2,048,000 |
分析与讨论
-
精度保持:即使在仅保留40%奇异值(k=400)的情况下,Top1准确率仅下降约0.35%,Top5准确率下降约0.25%,说明Truncated SVD能有效保留最重要的特征。
-
参数减少:当k=400时,参数数量减少约40%,从2,048,000降至1,219,200。
-
计算效率:虽然参数减少带来了一定的计算量降低,但由于需要两个连续的矩阵乘法,实际加速效果可能不如参数减少那么显著。
-
进一步优化:可以结合微调(fine-tuning)来恢复部分精度损失,或者与其他压缩技术(如量化)结合使用。
结论
Truncated SVD是一种简单有效的神经网络压缩方法,特别适用于全连接层的优化。通过Distiller项目提供的工具,我们可以方便地在PyTorch中实现这一技术,并在ResNet50等大型模型上验证其效果。实验结果表明,即使不进行微调,Truncated SVD也能在显著减少参数量的同时保持较高的模型精度。
这种方法为模型部署在资源受限设备上提供了可行的解决方案,是模型压缩工具箱中值得掌握的重要技术之一。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考