在 MMDetection 中,如果需要设置某个模块下的所有权重不参与梯度计算,可以通过以下步骤实现。具体实现方式包括冻结参数以及确保这些参数不会被优化器更新。
方法 1:设置 requires_grad=False
这是最直接的方法,通过将模块中的所有参数的 requires_grad
属性设置为 False
,使其在训练过程中不计算梯度。
代码示例:
# 冻结当前模块下的所有参数
for param in model.module_name.parameters():
param.requires_grad = False
在 MMDetection 中,假设你有一个 teacher
模块需要冻结:
# 假设 model.teacher 是需要冻结的模块
for param in model.teacher.parameters():
param.requires_grad = False
这样处理后:
- 该模块的权重在前向传播时不会参与梯度计算。
- 这些参数不会被优化器更新。
方法 2:过滤优化器的参数
通过调整优化器的参数列表,仅让需要计算梯度的参数参与优化。
代码示例:
# 过滤出需要计算梯度的参数
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4
)
如果使用 MMDetection 配置文件中的 optim_wrapper
,可以使用 paramwise_cfg
来显式指定:
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='Adam', lr=1e-4),
paramwise_cfg=dict(
custom_keys={
'teacher.': dict(lr_mult=0.0, decay_mult=0.0), # 冻结 teacher 模块
}
)
)
这里的 custom_keys
配置会自动设置 teacher
模块中的参数不更新。
方法 3:使用 torch.no_grad()
包裹前向传播
如果某个模块的参数不仅不需要计算梯度,也不需要参与反向传播,可以在前向传播中使用 torch.no_grad()
包裹计算逻辑。
代码示例:
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.teacher = TeacherModule()
self.student = StudentModule()
def forward(self, x):
with torch.no_grad(): # 冻结 teacher 的前向传播
teacher_output = self.teacher(x)
student_output = self.student(x)
return student_output
这样可以确保 teacher
的参数在前向传播时不记录计算图,进一步减少显存消耗。
方法 4:设置 eval()
模式
如果某个模块是预训练的,且在当前任务中只需要作为一个固定的特征提取器,可以将其切换到 eval()
模式。eval()
模式会自动关闭 dropout
和 BatchNorm
的统计更新。
代码示例:
model.teacher.eval() # 切换 teacher 模块为 eval 模式
需要注意的是,eval()
仅影响前向传播中的某些特定操作(如 BatchNorm
和 Dropout
),并不自动设置 requires_grad=False
。
方法 5:自定义模型权重冻结函数
如果需要灵活控制,可以编写一个函数来自动冻结指定模块的参数。
代码实现:
def freeze_module(module):
for param in module.parameters():
param.requires_grad = False
# 冻结特定模块
freeze_module(model.teacher)
验证冻结效果
在训练开始前,可以通过以下代码检查冻结是否成功:
for name, param in model.named_parameters():
print(f"{name}: requires_grad={param.requires_grad}")
你应该看到指定模块的参数对应的 requires_grad=False
。
总结
在 MMDetection 中冻结某个模块下的所有权重可以通过以下方式实现:
- 设置
requires_grad=False
:直接控制参数是否需要计算梯度。 - 过滤优化器参数:通过优化器配置避免优化不需要的参数。
- 使用
torch.no_grad()
:防止在前向传播中记录计算图。 - 设置
eval()
模式:适用于预训练模块,只需固定特征提取。 - 自定义冻结函数:更灵活地批量操作模型参数。