人类在学习过程中会不断总结经验,在完成许多任务的学习后,当面对新任务时,可以使用积累的经验迅速适应新的任务。对比元学习,人类学习不同任务的过程就是元学习训练的过程,而最后人类所积累到的经验就是元学习所得到的knowledge,即对于机器模型而言的经验。
所以元学习的重点就是使模型学习好多好多的任务,即Task。这些Task可以是相同的任务,也可以是不同的任务,可以都是分类问题,也可以分类中参杂着回归等等。只要模型在解决这些Task时,能够学习到我们期望它所需的知识,这些Task可以是任何配置。既然是Task,那么它就应该是一个个单独的任务,所以其中既包括训练数据又包括测试数据。
最简单的Task格式是这样的(取台大达李弘毅老师的PPT):
元学习最后得到是knowledge,也就是经验。在不同的网络架构中这个knowledge可以有不同的体现,它可以是一个预测不同模型参数的能力,或者可以是一个embedding函数,也可以是一个有着效果不错初始化参数。
依据实现方式的不同,目前最常见的元学习方式包括三种:
- 基于RNN等模式的记忆存储
- 基于参数初始化方式的MAML,Reptile等
- 基于度量学习的Matching Network,Prototypical Network等
对于第一种 “基于RNN等模式的记忆存储” 的元学习方式不是很了解,直接从第二种开始尬!!
基于参数初始化方式的MAML,Reptile等
这种方式主要就是在寻找一套参数,当使用这套参数解决一个新任务时,只需要少量的样本或者几次梯度下降就可以在新任务上取得不错的效果。
MAML
其中MAML属于这一块研究的开山之作,对于MAML而言,生成参数的网络和Task中使用的网络是一致的。其主要思路是:
- 当面对不同的Task,“MAML”会为这些任务生成一套初始化参数
- 每个Task将得到的初始化参数在自己的Support Set计算几次更新后,取平均方向得到更新后的参数,将更新后的参数作用在Query Set上计算损失,这个损失将回传给“MAML”网络,来指导下次参数的生成。如果这个损失很小,说明“MAML”生成的初始化参数,面对一个新的Task,只需几步调整,就能在测试集取得不错的效果,反之就不求行!!

每个Task的损失回传过程中会涉及Query Loss以及Support Loss的链式,所以本质上是一个对损失的二阶导数,优化实现过程中不是很容易实现,所以在MAML的实现,直接使用Query Loss来进行参数更新。如下图:

基于度量学习的Matching Network,Prototypical Network等
基于度量学习的元学习方式主要在学习一个Embedding函数,这个Embedding函数就代表元知识。这种方式大多在Few-Shot 中解决小样本的分类问题,使用Embedding函数将每个Task中的样本映射到指定空间中,然后在此空间中使用各种距离完成分类。
Siamese Network

Matching network

Prototypical networks

- 将相同类别的数据映射为一个“原型”
- 计算测试样本与各个原型之间的距离
- 选取距离最小的作为预测类别
Relation Network

- 使用f提取训练集与测试样本的特征:f(xi),f(xj)
- 将f(xi),f(xj)拼接
- 使用g产生一个(0,1)间的标量来预测f(xi),f(xj)的相似度
参数预测
对于小样本而言,思路就是,在训练集上,获得每个类在大量样本下训练得到的参数W,同时获得对应类在小样本下训练得到的参数W0,所以获得大量的(W0,W),然后学习一个映射函数T:W* = T(W0),在T训练完毕后,当面对新的小样本时,直接使用T将W0映射到W*,完成模型**
Learning to Learn: Model Regression Networks for Easy Small Sample Learning
Meta-Learning to Detect Rare Objects
- 将物体识别分为两部分(1)category-agnostic(类别无关模块)(2)category-specific(类别分类模块)
- 类别无关模块,对于base和novel类是通用的,同时使用一个映射函数将base的category-specific参数映射到novel类别中
Low-Shot Learning with Imprinted Weights
- 具体分析了度量损失和softmax损失的一致性
- 使用imprinting预测新类的权重
Dynamic Few-Shot Visual Learning without Forgetting
- 首先训练得到Wbase完成Cbase的分类,对于新的样本,将Wbase和小样本数据集输入到generator完成Wnovel的预测
- Wnovel预测过程中,会使用attention筛选Wbase