pytorch如何freeze模型参数

本文介绍了在PyTorch中如何冻结和解冻模型参数,这对于迁移学习和自监督学习至关重要。通过提供的函数`set_freeze_by_names`和`set_freeze_by_idxs`,可以根据层名或索引来方便地控制模型参数是否参与反向传播。这些函数允许用户选择性地更新模型的特定部分,在预训练模型上进行微调或专注于训练新添加的层。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pytorch如何freeze模型参数

在做迁移学习或者自监督学习时,一般先预训练一个模型,再将该模型参数作为目标任务模型的初始化参数,或者直接freeze预训练模型,不再更新其参数。

今天记录下如何pytorch freeze模型参数

我是参考知乎一个文章,总结的很完整,我直接拿过来用了,原文出处为

https: // www.zhihu.com / question / 311095447 / answer / 589307812
from collections.abc import Iterable


def set_freeze_by_names(model, layer_names, freeze=True):
    if not isinstance(layer_names, Iterable):
        layer_names = [layer_names]
    for name, child in model.named_children():
        if name not in layer_names:
            continue
        for param in child.parameters():
            #print(param.name)
            param.requires_grad = not freeze


def freeze_by_names(model, layer_names):
    set_freeze_by_names(model, layer_names, True)


def unfreeze_by_names(model, layer_names):
    set_freeze_by_names(model, layer_names, False)


def set_freeze_by_idxs(model, idxs, freeze=True):
    if not isinstance(idxs, Iterable):
        idxs = [idxs]
    num_child = len(list(model.children()))
    idxs = tuple(map(lambda idx: num_child + idx if idx < 0 else idx, idxs))
    for idx, child in enumerate(model.children()):
        if idx not in idxs:
            continue
        for param in child.parameters():
            param.requires_grad = not freeze


def freeze_by_idxs(model, idxs):
    set_freeze_by_idxs(model, idxs, True)


def unfreeze_by_idxs(model, idxs):
    set_freeze_by_idxs(model, idxs, False)

 

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值