魔改模型的正确姿势(如何高效继承模型权重)

  参考资料:

  正则表达式语法

  众所周知,在人工智能领域经常会有权重继承的需求。因为继承一个好的预训练模型的权重,能够极大程度地加快模型的收敛速度。很多时候,能否很好地继承(魔改)已有的预训练好的模型,直接决定了最终模型的效果。然而,这一过程并不是一目了然、有手就行的,如果你想大改、自己实现一个算法框架,同时又想用上已有模型的能力,就会面临这种相当困难且耗时的工作。下面我分小节讲一讲我在自己魔改模型的过程中整理出来的workflow,方便日后查阅,并供读者参考。

  1. 分清自己的目标

  魔改魔改,不改怎么叫魔改。改有两个方向,一种是基于别人的代码和权重,将别人网络中的一些层进行替换,然后重新进行训练。这种一般来说比较简单,因为state_dict的key大部分都能对上(毕竟还是那个源文件)。如果需要修改的层还比较多的话,可能就需要用到module的搜索了,比较推荐的是使用nn.Module自带named_modules方法进行便利搜索。

  另一种就是基于自己的风格,用自己的代码框架构建模型,然后从别人的权重中把有用的权重值“薅过来”。这种改法就比较困难,但好处就是通用性比较强。如果你有一套成熟的训练体系和配置文件体系,那么使用这种方式显然更加优雅且统一。本文着重讲解的也是这种风格,此时源模型和目标模型基本上是异构的。以一个实际例子为例,StableDiffusion3的VAE是针对于2D图片来设计的,我想把它魔改成能够处理3D视频的视频VAE。

  2. 检视目标Model的State_dict和源权重文件的State_dict

  首先,初始化你的模型(nn.Module),并加载源权重文件(ckpt或safetensors)。使用下面的代码打印出所有的层,检视一下key有什么不同:

for k,v in VAEModel.state_dict().items():
    print(k, "》》》", v.shape)

  最好是存成txt文件放到大屏幕上分个屏,人工观察,在这个过程中可以手工调整一下位置和分段,主要是要你对模型的结构和对应关系有个清晰的认识。

  3. 通过正则表达式替换权重State_dict中的k和v

  人工观察之后发现,大部分的名字都是完全可以对上的,现在我们需要把state_dict进行修改,此时就可以请出我们的利器:正则表达式,因为它可以极大程度上减少我们的重复替换工作,例如写无数个starts_with。就非常呆。。。比如,我发现源的encoder.down_blocks.0.resnets.0.norm1.weight和目标的encoder.down_list.0.block.0.norm1.norm.weight是对应的关系,就可以用如下的代码进行替换:

pat1  = r'(down|up)_blocks\.(\d+)\.resnets\.(\d+)\.norm([12])'
repl1 = r'\1_list.\2.block.\3.norm\4.norm'

pat_list = [(pat1, repl1)]

A = safetensors.torch.load_file("XXX")
B = deepcopy(A)

for k, v in B.items():
    for pat, repl in pat_list:
        if re.search(pat, k):
            ## 魔改权重 v
            pass
        
        A[re.sub(pat, repl, k)] = v
        del A[k]

Model.load_state_dict(A, strict=True)

  4. 对精度

  在我们成功继承权重之后,对精度一般也是必不可少的。比如我魔改的3D VAE,我就希望它在继承权重之后能够很好地图片。此时我会使用源模型推理一遍,然后再用魔改之后的模型推理一遍,保存中间隐藏层的特征。最后使用torch自带的方法计算这些隐藏层是否完全对得上。一般来说是从高层到底层,从粗粒度到细粒度。如果你发现两个代码框架对于某个计算逻辑抽象的层级不一样,比如模型甲多包了一层,此时精度误差又刚好发生在这多包的一层中,我的建议是逐层展开。仍然逐层打印隐藏层的激活值,存成.pt文件,对精度,直到最后精度一致。建议使用Jupyter notebook,配合autoreload,效率更高哦。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值