论文阅读《Few-shot Object Detection via Feature Reweighting》

提出一种小样本模型,通过特征重加权快速适应新类别。利用大量有标签数据提取元特征,结合少量新类数据,实现高效检测。

Few-shot Object Detection via Feature Reweighting

提出了一种检测新颖类别的小样本模型,该新颖类别仅包含少数数据。充分利用基类(base classes)中有标签数据,使用一个特征提取模块和一个 reweighting 模块在一个一阶段检测网络中,来快速适应以达到实现新颖类别检测的目的。

元特征提取模块利用大量有标签数据提取那些具有通用性的,可以用来辅助检测新类别的元特征。从具有足够样本的基类中提取元特征来检测新的对象类。

reweighting 模块将新颖类别中的少量 support set 转换成一个全局向量,这个全局向量是元特征中对检测这些新颖类别物体非常重要或者十分相关的特征,实现将元特征的知识迁移到新类别。利用这个模块,可以激发一些用于检测新颖目标的元特征,从而辅助检测。

具体来说,给定一张 query image 和一些同类别的 support image,用训练好的特征提取模块提取 query 的元特征。reweighting 模块学习 support 中每个物体类别的全局特征,得到每个类别的 reweighting vector。再用这些 vector 通过调整对应的系数来调整 query 的元特征的分布,这样 query 的元特征融合了 support 里的信息,更加适合用来实现新颖类别的检测,这么多的 reweighting vector 使得 query image 里的那些对识别新物体重要并且贡献大的元特征被更好地利用起来。也可以理解为 reweighting vector 对不同的类别产生不同的激活系数,来调整 query 通过特征提取模块得到的元特征。

本文提出的模型是端到端的训练模型,该模型分为两个阶段进行训练。第一个阶段先学习特征提取模块提取元特征的能力和 reweighting 模块,第二阶段再用新颖类别的数据对检测模型(?)进行微调。精心设计了一个损失函数来解决无关类别的问题。

具体实现

问题定义

训练时有两类数据,基类和新颖类(novel classes),基类中包含大量有标签数据,而新颖类则很少。

### 实现断点续训功能的关键要素 在 Frustratingly Simple Few-Shot Object Detection (FSCE) 模型中实现断点继续训练的功能,需要通过保存和加载模型的状态字典以及优化器的状态字典来完成。以下是实现这一功能的具体方法: #### 1. **保存模型与优化器状态** 在训练过程中,可以通过 `torch.save` 方法保存模型的权重和优化器的状态。以下是一个示例代码,展示如何在训练过程中定期保存检查点: ```python import torch # 假设 model 是你的 FSCE 模型实例,optimizer 是优化器实例 def save_checkpoint(model, optimizer, epoch, loss, path): """ 保存模型、优化器状态及训练信息。 """ torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, }, path) ``` 上述代码将模型的状态字典(`model.state_dict()`)、优化器的状态字典(`optimizer.state_dict()`)、当前的训练轮数(`epoch`)以及损失值(`loss`)保存到指定路径中[^1]。 #### 2. **加载模型与优化器状态** 在恢复训练时,需要从保存的检查点中加载模型和优化器的状态。以下是一个加载检查点的示例代码: ```python def load_checkpoint(model, optimizer, path): """ 加载模型、优化器状态及训练信息。 """ checkpoint = torch.load(path) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] return epoch, loss ``` 通过上述代码,可以从保存的检查点中恢复模型的权重、优化器的状态、训练轮数以及损失值,从而继续训练[^1]。 #### 3. **整合断点续训逻辑** 为了实现自动化的断点续训功能,可以在训练脚本中添加条件判断,检查是否存在之前的检查点文件。如果存在,则加载该文件并从中恢复训练;否则,从头开始训练。以下是一个完整的训练流程示例: ```python import os # 定义检查点路径 checkpoint_path = "fsce_checkpoint.pth" # 初始化模型和优化器 model = ... # 初始化 FSCE 模型 optimizer = ... # 初始化优化器 # 判断是否已有检查点 if os.path.exists(checkpoint_path): print("Loading checkpoint...") start_epoch, prev_loss = load_checkpoint(model, optimizer, checkpoint_path) else: print("Starting training from scratch...") start_epoch = 0 prev_loss = None # 训练循环 for epoch in range(start_epoch, num_epochs): for data in dataloader: # 训练步骤 ... # 保存检查点 save_checkpoint(model, optimizer, epoch, current_loss, checkpoint_path) ``` 上述代码展示了如何在训练过程中定期保存检查点,并在启动训练脚本时检查是否存在之前的检查点文件以决定是否恢复训练。 --- ### 注意事项 - 在保存和加载检查点时,确保使用相同的设备(CPU 或 GPU)。如果需要在不同设备间切换,可以使用 `map_location` 参数。例如:`torch.load(path, map_location=torch.device('cpu'))`。 - 如果训练过程中涉及学习率调度器(`lr_scheduler`),也需要将其状态保存和加载,类似于优化器的操作。 ---
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值