transformers 笔记:自定义模型(配置+模型+注册为AutoCLass+本地保存加载)

1 配置(Configuration)

1.1 自定义配置

  • 自定义配置类的要点:
    • 必须继承自 PretrainedConfig,以继承 from_pretrained()save_pretrained() 等功能;
    • 构造函数 __init__() 必须接收任意 **kwargs 并传给父类;
    • 添加 model_type 属性,以支持 AutoClass;
    • 可以加入参数校验逻辑。

1.2 保存配置

resnet50d_config = ResnetConfig(block_type="bottleneck", 
                    stem_width=32, 
                    stem_type="deep", 
                    avg_down=True)
resnet50d_config.save_pretrained("custom-resnet")

2 模型结构

  • 模型类需要继承自 PreTrainedModel,并接受配置对象作为输入
  • Transformers 约定模型的所有超参数由配置对象提供
  • 可以构建两种模型:

2.1 裸模型(输出隐藏状态)

2.2 带分类头的模型(支持 Trainer,输出 logits 和 loss)

2.3 加载预训练权重

import timm

resnet50d = ResnetModel(resnet50d_config)
#此时 resnet50d.model 就是一个结构为 ResNet-50d 的模型,但权重是 随机初始化的,没有训练。

pretrained_model = timm.create_model("resnet50d", pretrained=True)
#从 timm 加载已经训练好的 resnet50d 模型

resnet50d.model.load_state_dict(pretrained_model.state_dict())

3 启用 AutoClass 支持

AutoClass API 能自动根据配置加载模型,简化用户调用

需要:

  1. 在配置类中加入 model_type

  2. 在模型类中加入 config_class

  3. 使用 AutoConfig.register()AutoModel.register() 注册。

from transformers import AutoConfig, AutoModel, AutoModelForImageClassification


AutoConfig.register("resnet", ResnetConfig)
#注册自定义配置类 ResnetConfig。
#"resnet" 是 ResnetConfig.model_type,它必须和配置类中的 model_type = "resnet" 一致。
#注册后,用户可以通过 AutoConfig.from_pretrained() 自动加载这个配置类。




AutoModel.register(ResnetConfig, ResnetModel)
#把裸模型类 ResnetModel 绑定到 AutoModel。
'''
这样用户就可以用如下方式加载模型:
model = AutoModel.from_pretrained("your-username/custom-resnet50d", trust_remote_code=True)
'''



AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification)
#注册了你带分类头的模型 ResnetModelForImageClassification 到 AutoModelForImageClassification。
'''
用户可以像这样加载:
model = AutoModelForImageClassification.from_pretrained(
    "your-username/custom-resnet50d", trust_remote_code=True
)
'''

4 本地保存& 加载特定模型

 假设已经定义和注册配置和模型,并加载了预训练权重

resnet50d_config = ResnetConfig(block_type="bottleneck", 
    stem_width=32, 
    stem_type="deep", 
    avg_down=True)
#加载自定义config

resnet50d = ResnetModelForImageClassification(resnet50d_config)
#加载自定义model


# 加载预训练权重
import timm
pretrained = timm.create_model("resnet50d", pretrained=True)
resnet50d.model.load_state_dict(pretrained.state_dict())



注册 AutoClass 支持,保存 AutoClass 映射信息

resnet50d_config.register_for_auto_class()
resnet50d.register_for_auto_class("AutoModelForImageClassification")


保存模型和配置到本地

resnet50d.save_pretrained("custom-resnet50d/")
resnet50d_config.save_pretrained("custom-resnet50d/")

4.1 本地重新加载

from transformers import AutoModelForImageClassification

# 加载模型
model = AutoModelForImageClassification.from_pretrained(
    "custom-resnet50d/", trust_remote_code=True
)

由于使用的是自定义模型类,加载时一定要加上trust_remote_code=True

4.2 保存后的本地目录

4.3 为什么要保存config?

  • config 是必须保存的,因为 AutoModel 是依赖 config.json 来决定加载哪个模型类。
  • AutoModel.from_pretrained("path_or_repo")背后的机制是
    • 先加载配置文件 config.json
      • config = AutoConfig.from_pretrained("path_or_repo")
    • 根据 config.model_type 决定使用哪个模型类
      • "model_type": "resnet" → 查找注册的 ResnetModel
    • 再加载权重文件(.bin 或 .safetensors)到模型中

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值