import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import prune
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = SimpleNet()
# 随机剪枝
prune.random_unstructured(model.fc1, name='weight', amount=0.8)
# 检查剪枝状态
print(prune.is_pruned(model.fc1))
# 获取剪枝掩码
mask = model.fc1.weight_mask
print(mask)
# 移除剪枝
prune.remove(model.fc1, name='weight')
一个简单的剪枝示例,这里剪掉了80%是为了演示效果,
实际操作中可不能这么狠心,尤其对于这种随机剪枝
输出结果:
True
tensor([[1., 0., 1., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 1.],
[1., 0., 0., ..., 0., 1., 0.],
...,
[0., 0., 0., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 0., 1.],
[0., 0., 0., ..., 0., 1., 0.]])
Linear(in_features=784, out_features=256, bias=True)
再来看看几种常见的剪枝
torch.nn.utils.prune 提供了多种剪枝方法,常用的有以下几种:
1. L1 剪枝(L1 Unstructured Pruning)
-
描述:基于权重的 L1 范数剪枝,移除权重最小的部分。
-
用法:
python
复制
prune.l1_unstructured(module, name='weight', amount=0.5)
2. 随机剪枝(Random Unstructured Pruning)
-
描述:随机移除权重。
-
用法:
python
复制
prune.random_unstructured(module, name='weight', amount=0.5)
3. L2 结构化剪枝(L2 Structured Pruning)
-
描述:基于权重的 L2 范数剪枝,移除整个神经元或通道。
-
用法:
python
复制
prune.ln_structured(module, name='weight', amount=0.5, n=2, dim=0)
4. 全局剪枝(Global Pruning)
-
描述:跨多个层进行全局剪枝。
-
用法:
python
复制
parameters_to_prune = ((model.fc1, 'weight'), (model.fc2, 'weight')) prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.5)
5. 自定义剪枝(Custom Pruning)
-
描述:自定义剪枝方法。
-
用法:
python
复制
def custom_pruning_mask(weight, amount): mask = torch.ones_like(weight) # 自定义剪枝逻辑 return mask prune.custom_from_mask(module, name='weight', mask=custom_pruning_mask(module.weight, amount=0.5))
6. 移除剪枝(Remove Pruning)
-
描述:移除剪枝,恢复原始权重。
-
用法:
python
复制
prune.remove(module, name='weight')
7. 检查剪枝状态(Check Pruning Status)
-
描述:检查模块是否被剪枝。
-
用法:
python
复制
print(prune.is_pruned(module))
8. 获取剪枝掩码(Get Pruning Mask)
-
描述:获取剪枝掩码。
-
用法:
python
复制
mask = prune.get_mask(module, name='weight')
示例代码
python
复制
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import prune
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = SimpleNet()
# L1 剪枝
prune.l1_unstructured(model.fc1, name='weight', amount=0.5)
# 检查剪枝状态
print(prune.is_pruned(model.fc1))
# 获取剪枝掩码
mask = prune.get_mask(model.fc1, name='weight')
# 移除剪枝
prune.remove(model.fc1, name='weight')
总结
torch.nn.utils.prune 提供了多种剪枝方法,包括 L1 剪枝、随机剪枝、L2 结构化剪枝、全局剪枝和自定义剪枝等。通过这些方法可以有效减少模型复杂度,提升效率。
随机剪枝是怎么敢的啊?不怕裁员裁到大动脉?
随机剪枝虽然听起来有些冒险,但在实际应用中并不会“裁到大动脉”,原因如下:
1. 剪枝比例控制
-
小比例剪枝:通常只剪枝一小部分权重(如10%-20%),对模型影响有限。
-
逐步剪枝:可以逐步增加剪枝比例,避免一次性剪枝过多。
2. 模型冗余
-
冗余设计:神经网络通常有大量冗余,移除部分权重不会显著影响性能。
-
微调恢复:剪枝后通过微调可以恢复甚至提升模型性能。
3. 随机性优势
-
探索潜力:随机剪枝可能发现新的高效结构。
-
避免偏差:不依赖特定准则,避免引入偏差。
4. 实验验证
-
实验支持:研究表明,随机剪枝在某些情况下效果与基于重要性的剪枝相当。
-
广泛应用:随机剪枝已被用于许多实际任务,证明了其有效性。
剪枝技术:提升模型效率的策略
2908





