【Pytorch】prototypical network原型网络小样本图像分类简述及其实现

基本概念

小样本学习(Few-Shot Learning, FSL)任务,顾名思义,就是能够仅通过一个或几个示例就快速建立对新概念的认知能力。实现小样本学习的方式也有很多,比如:度量学习、数据增强、预训练模型、元学习等等,其中元学习是目前广泛使用的处理小样本学习问题的方法。

元学习(meta learning或learning to learn),也称学会学习,元学习算法能够在学习不同任务的过程中积累经验,从而使得模型能够快速适应新任务。

元学习与一般的监督学习的区别:

一般的监督学习是在训练集上训练出一个函数映射 f,这个函数可以识别出哪张图片是狗,哪张是猫。输入是某张图片,输出是标签。

元学习算法则是让模型学会学习,即在训练集上学习出一个函数 FF 可以自动学习出一个函数 f^*,他可以分辨出哪张图片是狗,哪张是猫。F 的输入是一个个的图片集合,输出是函数 f^*, f^* 的输入是某张图片,输出是标签。 

可以这么理解使用元学习算法的小样本学习任务:我有一个数据集{大象,老虎,狮子},小样本学习并非是让模型识别出哪个是老虎、大象或者狮子,而是学习出每个类别之间的差异,以便在新的数据集(比如:{汽车、电视、沙发、鼠标})中更好的分类。

小样本学习图片分类的基本思想

为了更形式化评估元学习算法,在分类问题上,元学习的数据形式和一般监督学习的数据形式也有所不同,最小的数据点不再是一张图片,而是一个一个的小任务。每个小任务中有 N 个类别,每个类别有 K 张图片,我们称这些任务为N-way K-shot图像分类任务,一共有 episodes 个小任务。当K值很小时(一般K<10),该任务就是小样本图像分类任务了。当K=1时,该任务即为单样本图像分类任务。

除此之外我们还需要知道两个重要概念:

  • 支持集(Support Set):相当于每个小任务中的训练集,包含N个分类标签,每个标签有K张图片。
  • 查询集(Query Set):相当于每个小任务中的测试集,包含Q张未分类的图片。

如下图,为一个3-way 2-shot图像分类任务,蓝色板块是支持集,绿色的是查询集:

 注意,对于元学习而言,上图的3-way 2-shot图像分类任务只是一个数据点,完整的数据集及其训练集-测试集划分如下图所示:

  

元学习流派

Black Box / Model-based
为小数据集场景专门制定一个能够快速变化参数的模型。代表作有:MANN,MetaNet等。

Optimization Based
通过让模型快速优化自己的参数来实现小样本学习。代表作有:MAML,NAIL,Reptile等。

Metric Based(基于度量的方法)
也是目前主流的方法,通过学习一个Encoder,将数据映射到一个表征空间,然后使用无参的Decoder来进行分类。代表作有Matching Network,Prototypical Network等。

Prototypical Network基本原理

以一个episode为例,其中包含 {狗,虫子,鸟} 三个类别的图片。

1、首先,对支持集中的每张图片使用编码器  f_\varphi (x_i) 进行信息提取,学习到每张图片的Embedding编码表示。(编码器可以选择常规的卷积操作、resnet系列、vit等等)

2、 然后对支持集的每个类别下的Embeddings做均值处理,得到每个类别的原型表示(class Prototype)。

  3、对查询集中的图片进行分类。首先使用编码器 f_\varphi (x_i) 将查询集图片进行编码,得到该图片的Embedding向量表示。

 4、然后拿着这个Embedding表示和类别原型进行相似度计算,也就是无参的解码过程。(相似度计算的方式很多,可以是欧氏距离或者余弦相似度等)

 

 5、计算完相似度后,往往还需要使用softmax将相似度激活成概率分布。

最终得到查询集图片的分类标签,然后和真实值标签做交叉熵loss,然后梯度反向传播即可完成一个episode的训练。

Prototypical Network算法描述

1、假设原始数据集为D,对于每一个episode,包含一个支持集和一个查询集,即D_{episode}=D_{support}\cup D_{query}=\left \{ s_i \right \}^{n_s}_{i=1}\cup \left \{ q_i \right \}^{n_q}_{i=1}  。

实现方法就是在原始数据集 D 中随机选取N个类别,每个类别选取K张图片,构成支持集,选取Q张图片,构成查询集,这样就组成了一个episode的小数据集。以此类推,构造 episodes 个小数据集。

2、 对每张图片,利用Encoder进行特征提取,即

 hs_i=f_\varphi (s_i) 

 hq_i=f_\varphi (q_i) 

3、计算出支持集中的每个类别的原型(prototype),即

p_{c_j}=\sum _{\left \{ i|l_{s_i}=c_j \right \}}hs_i

其中,l_{s_i} 表示图片 s_i 的类别标签。

 4、接下来计算每个查询集图片Embedding与每个类别的相似度,即

p(\widehat{l}_{q_i}=c_j)=\frac{exp(sim(hq_i,p_{c_j}))}{\sum ^{|c|}_{k=1}exp(sim(hq_i,p_{c_k}))}

5、训练用的损失函数,公式如下:

L_{q_i} = CrossEntropy(l_{q_i},\widehat{l}_{q_i})

Prototypical Network 的 Pytorch实现

prototypical network是有官方的论文实现的【prototypical-network源码】,而且很多框架里自带原型网络的包,直接调用即可。

但是官方的论文源码比较难看,而且某些场景需要拆解组合时也不方便,因此这里我自己实现了一个精简版的原型网络【我自己的代码复现】

总结

一般来说way与shot准确率的关系如下所示:

 这个很好理解,一个episode中类别(way)越少,就越容易找出图片之间的异同,比如二分类就比十分类容易一些;一个episode中同一个类别的样本(shot)越多,就越容易找出图片之间的异同。

本质上来说,原型网络就是集成学习中的stacking思想。


题外话

一般来说,对于一个稍微有点机器学习基础的小白,要想在一个未知的ML分支领域快速实现一个比较靠谱的实例(比如打一个小样本学习的算法挑战赛,但我从没接触过小样本学习)。总体步骤如下:

  1. 先用搜索引擎搜索几个相关条目,大体了解这个领域是干什么的;
  2. 找一些该领域的大腿、论坛或者订阅号,如sota排名网页、公众号、添加一下文献鸟的关键词等。最好找个人带
  3. 找几篇这个领域最新的综述,大体了解一下该领域各个流派之间的腥风血雨;
  4. 看几个这个领域比较火的代码实例;
  5. 找几篇最新的SOTA,最好带代码的;
  6. 添加自己的想法,落地实现并验证。

参考资料:

综述:

https://arxiv.org/abs/2004.05439

【干货】如何通过元学习解决小样本图像分类任务 - 知乎

Meta Learning(元学习)_哔哩哔哩_bilibili

SOTA:

Few-Shot Classification Leaderboard

https://arxiv.org/abs/1703.05175v2

TechBeat

简单的实例:

https://arxiv.org/abs/1703.05175

元学习——原型网络(Prototypical Networks)_hei653779919的博客-优快云博客_原型网络

元学习—MAML模型Pytorch实现_hei653779919的博客-优快云博客

### 回答1: 原型网络是一种基于距离度量的分类模型,它通过计算输入样本与每个类别的原型之间的距离来进行分类。原型可以是类别的平均值、中心点或代表性样本。该模型在小样本学习中表现出色,因为它不需要大量的训练数据就可以进行分类。 ### 回答2: 原型网络是一种用于处理图像分类任务的深度学习模型。它的灵感来源于人类对物体的认知过程,即将物体与所了解的样本进行比较,然后将其分类为属于哪个类别。原型网络的核心思想就是通过比较输入图像和已知样本之间的相似程度来确定图像的类别。 原型网络的架构非常简单,它由两个主要部分组成:卷积层和比较层。卷积层将输入的图像转换为特征向量,并提取其特征。比较层由多个原型向量组成,每个原型向量代表类别中心,即样本的平均值。当输入图像被传递到比较层时,将计算输入向量与每个原型向量之间的距离,距离越小,图像就越可能属于该类。 与其他深度学习模型相比,原型网络具有许多优点。首先,它的计算效率非常高。由于原型网络只需要计算每个类别的平均值,因此它可以很容易地扩展到具有大量类别的大规模数据集。其次,原型网络的可解释性非常好。由于每个原型向量代表一个类别,因此可以轻松地解释哪些特征被认为是用于识别该类别的最重要的特征。 总之,原型网络是一种非常简单而又有效的图像分类模型。它能够高效地处理大规模数据集,并具有很好的可解释性,这使其成为了一种流行的深度学习模型。 ### 回答3: 原型网络prototypical network)是一种基于原型的学习方法,用于分类任务。它被认为是一种简单而有效的深度学习模型,已经在语音识别、图像分类等领域得到了广泛应用。该模型的基本思想是将类别信息映射到一个低维的原型空间中,从而通过计算距离来进行类别分类原型网络的训练过程可以概括为以下几个步骤:首先,对于每个类别,从训练数据中挑选一些样本作为该类别的原型。这些原型可以是某个样本的平均值、中心点或其他特定的统计量。接着,对于每个样本,计算其到各个原型的距离,并选取距离最小的那个原型作为该样本所属的类别。最后,通过使用交叉熵等损失函数来对模型进行训练,使得原型能够更好地区分不同的类别。 相比于传统的深度学习模型,原型网络具有以下几个优势:首先,由于其使用了低维的原型空间,计算速度较快且存储空间占用较小。其次,模型的可解释性较高,可以直观地对原型进行分析,从而得到更深入的理解。最后,原型网络对于小样本学习具有一定的优势,即在只有少量训练数据的情况下,仍然可以取得较好的分类性能。 总之,原型网络是一种简单而有效的深度学习模型,具有较好的分类性能和可解释性,是当前学术界和工业界研究的热点之一。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值