Learning to Compare: Relation Network for Few-Shot Learning. (学习比较:用于few-shot learning 的关系网络)

探讨了关系网络(Relation Networks,RN)在Few-Shot学习中的应用,RN是一种端到端的网络,能通过学习深度距离度量,对新类别进行分类。该方法不仅提升了Few-Shot学习性能,还能扩展至Zero-Shot学习。RN由Embedding Module和Relation Module组成,通过Episode-based策略训练。

1. 摘要

文章提出了一种概念上简单、灵活、通用的框架用于 few-shot learning 问题。few-shot learning 问题需要分类器必须在每个新类只给出几个样本情况下识别新的类(新的类是指在训练阶段没有见过的类)。文章提出了网络叫做 — 关系网络(Relation Networks,RN),一个端到端(end to end)的网络。在 meta-training 阶段,网络学会学习一个深度距离度量在一个剧集(epoch)内比较少量几张图片,并且每个剧集(epoch)设计来模仿 few-shot learning 的设置。一旦训练好,RN 就可以在不进行任何更新的情况下对来自新的类的样本进行分类,分类过程:将 query 中的样本与 support 中样本计算关系得分(relation score)。RN 除了在 few-shot leanring 问题上提升了性能,而且可以很容易的扩展到 zero-shot leanring(零样本学习) 问题上。广泛的实验证明文章提出的简单的网络提供了一个对于 few-shot leanring 和 zero-shot leanring 统一且有效的方法。

(1)新的类;(2)每个类的训练样本只有一个;(3)学习一个网络映射输入空间到新的空间,比较相似度;
这三个特征满足 few-shot learning,meta-learning 的基本特征,而且可以看成属于 meta-learning 中的度量学习 (metric learning)。

(推荐阅读以下链接,对理解本文很有帮助:https://blog.youkuaiyun.com/weixin_37589575/article/details/92801610

2. 介绍

深度学习已经在计算机视觉领域取得了巨大的成功,然而有监督的学习模型需要大量的有标签的数据和大量的迭代,这严重的限制了对于新类别的可扩展性(对新数据的标记成本),还有一个更根本的问题,就是对新出现的东西(例如新设备)或者稀有种类(例如稀有动物)的实用性,这两种情况下,大量标记的样本根本不存在。相比之下,人类可以在具有很少的监督信息的情况下甚至没有监督信息的情况下快速学习新的概念(分别对应 few-shot leanring 和 zero-shot leanring)。例如一个小孩可以从书本上的单个图片推广归纳 “斑马” 的概念,甚至在可以听到一个类似于“条纹马”的描述下都可以学习到“斑马”的概念。然而我们最好的深度学习 (Deep Learning, DL) 模型却需要成百上千个样本。这个机器智能和人类智能的差距,推动鼓励我们来关注 few-shot learning 和 one-shot leanring 问题。数据增强技术和正则化技术可以缓解在少量数据集情况下的 过拟合 (overfitting) 问题,但是并没有很好的解决它。因此开始将 meta-learning 用于 few-shot learning 问题。

这种方法将训练分解成了一个辅助的 meta training 阶段,其中学习可以迁移的知识用于识别新的类。这种可以迁移的知识主要可以分为:

  1. 好的初始化。典型的有 MAML,《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》:https://www.cs.utexas.edu/~sniekum/classes/RL-F17/papers/Meta.pdf
  2. embeding 函数。典型的有 Matching Networks: https://blog.youkuaiyun.com/weixin_37589575/article/details/93844850,Prototypical Networks:https://blog.youkuaiyun.com/weixin_37589575/article/details/92768668
  3. 优化策略:《Optimization as a Model for Few-Shot Learning》:https://openreview.net/pdf?id=rJY0-Kcll.

相对应的,在新的需要识别新类的 meta-testing 的问题任务上:

  1. 基于学习到的初始化对网络进行微调;
  2. 直接前向传播而无需更新参数
  3. 基于学习到的优化策略对网络进行微调;

对于 zero-shot leanring,通常我们已有的信息不是“类别”这样的标签信息,可能是一个关于“类别”的一个描述信息,例如把不告诉你这个图片中的动物是“斑马”,而是告诉你图片中的动物是一种“条纹马”。

本文的方法可以分类为上述的第二种方法,但是最大的一点不同是,RN 不仅学习了一种 embedding 函数,还自己学习了一个度量。先前的方法中度量都是手动人为预先定义的,例如 Siamese Networks 和 Matching Networks 使用的余弦距离,以及 Prototypical Networks 中的欧氏距离。文章进一步学习了一个可迁移(而不是手动人为定义的)的度量来比较图片之间的关系(Relation)。这也可以看成学习一个非线性的比较器,而不是一个固定的线性的比较器,文中 Section 5 简单分析了一下这样的好处。这里又出现了归纳偏差:多次的 Embedding 和 Relation 的非线性学习阶段,这可以更容易学习到一个泛化能力强的解决方案。

具体而言,RN 包含两个主要的模块:

  1. Embedding Module:为样本生成一个表示(representation)。
  2. Relation Module:确定需要比较的样本是否来自于同一类。

训练的策略还是标准的在 meta-learning 中使用最多的基于剧集的策略(Episode-based strategy)。

3. 相关工作

看过我之前博客的朋友都应该知道,我一般不介绍相关工作的,但是这一篇论文的相关工作真的写得太好了。另外本文的作者 Flood Sung 可以在知乎直接搜索找到,他的文章特别的棒。

文章将 Meta-learning (Learning to learn)方法大致分为三类:

  1. 学习微调:MAML(《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》) 是这类方法的范例之一。MAML 的思想是学习一个良好初始化参数 (initialization parameter),这个初始化参数在遇到新的问题时,只需要使用少量的样本 (few-shot learning) 进行几步梯度下降就可以取得很好地效果( 参见后续博客 )。另一个典型是《Optimization as a Model for Few-Shot Learning》,他不仅关注于初始化,还训练了一个基于 LSTM 的优化器 (optimizer) 来帮助微调。
  2. 基于 RNN 的记忆存储:最直观的方法,使用基于 RNN 的技术记忆先前 task 中的表示等,这种表示将有助于学习新的 task。可参考《Meta networks》和 《Meta-learning with memory-augmented neural networks.》
  3. Embedding 和 度量学习(Metric Learning):主要可以参考《Learning a Similarity Metric Discriminatively, with Application to Face Verification.》,《Siamese neural networks for one-shot image recognition》,《Siamese neural networks for one-shot image recognition》,《Matching networks for one shot learning》,《Prototypical Networks for Few-shot Learning》,《Learning to Compare: Relation Network for Few-Shot Learning》。
    核心思想:学习一个 embedding 函数,将输入空间(例如图片)映射到一个新的嵌入空间,在嵌入空间中有一个相似性度量来区分不同类。我们的先验知识就是这个 embedding 函数,在遇到新的 task 的时候,只将需要分类的样本点用这个 embedding 函数映射到嵌入空间里面,使用相似性度量比较进行分类。

这里加一下我自己的一些理解:

  1. 基于 RNN 的记忆 (RNN Memory Based) 有两个关键问题,一个是这种方法经常会加一个外部存储来记忆,另一个是对模型进行了限制 (RNN),这可能会在一定程度上阻碍其发展和应用。
  2. 学习微调 (Learning to Fine-Tune) 的方法需要在新的 task 上面进行微调,也正是由于需要新的 task 中 support set 中有样本来进行微调,目前我个人还没看到这种方法用于 zero-shot learning 的问题上,但是在《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》的作者 Chelsea Finn 的博士论文《Learning to Learn with Gradients》中给出了 MAML 的理论证明,并且获得了 2018 ACM 最佳博士论文奖,还有一点就是 MAML 可以用于强化学习,另外两种方法多用于分类问题。链接:https://mp.weixin.qq.com/s/AdlwI-nbVlDWCj0o5LR7Sw
  3. 度量学习 (Metric Learning),和学习微调 (Learning to Fine-Tune) 的方法一样不对模型进行任何限制,并且可以用于 zero-shot learning 问题。虽然效果比较理想但是现在好像多用于分类任务并且可能缺乏一些理论上的证明,比如相似性度量是基于余弦距离还是欧式距离亦或是其他?为什么是这个距离?(因为 embedding 函数是一个神经网络,可解释性差,导致无法很好解释新的 embedding 空间),虽然本文将两个需要比较的 embedding 又送到一个神经网络(而不是人为手动选择相似性度量)来计算相似性得分,但是同样缺乏很好地理论证明。

4. 方法论

4.1 问题定义

在这里插入图片描述
这里为了方便读者和自己理解,对原论文的这一部分做了修改,用了最 meta-leaning 的解释(其实作者在知乎文章里面也是用的我这里的描述,不知道为什么在论文里面换了一种描述方式(可能为了描述更加简单),其实本质是一样的)。

在 few-shot learning 中有一个术语叫做 N N N-way K K K-shot 问题,简单的说就是我们需要分类的样本属于 N N N 个类中一种,但是我们每个类训练集中的样本只有 K K K 个,即一共只有 N ∗ K N * K N<

### 关于 Relation NetworkFew-Shot Learning 复现 #### 使用 PyTorch 实现 Relation Network 以下是基于 PyTorch 的 Relation Network 实现代码示例: ```python import torch import torch.nn as nn import torch.optim as optim from torchvision import transforms, datasets class CNNEncoder(nn.Module): """CNN Encoder""" def __init__(self): super(CNNEncoder, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=0), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2)) self.layer2 = nn.Sequential( nn.Conv2d(64, 64, kernel_size=3, padding=0), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2)) self.layer3 = nn.Sequential( nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU()) self.layer4 = nn.Sequential( nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU()) def forward(self, x): out = self.layer1(x) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) return out.view(out.size(0), -1) class RelationNetwork(nn.Module): """Relation Network""" def __init__(self, input_size, hidden_size): super(RelationNetwork, self).__init__() self.fc1 = nn.Linear(input_size*2, hidden_size) self.fc2 = nn.Linear(hidden_size, 1) def forward(self, x): out = torch.relu(self.fc1(x)) out = torch.sigmoid(self.fc2(out)) return out def train(model_encoder, model_relation, optimizer, criterion, support_set, query_set, device='cpu'): model_encoder.train() model_relation.train() support_features = model_encoder(support_set.to(device)) # Extract features from the support set query_features = model_encoder(query_set.to(device)) # Extract features from the query set relations = [] for i in range(len(support_features)): pair_feature = torch.cat([support_features[i], query_features], dim=-1) # Concatenate pairs of features relation_score = model_relation(pair_feature).view(-1) # Compute similarity score relations.append(relation_score) output = torch.stack(relations).squeeze() # Stack all scores into a tensor target = torch.zeros(output.shape[0]).to(device) # Create ground truth labels loss = criterion(output, target) # Calculate loss optimizer.zero_grad() loss.backward() optimizer.step() return loss.item() ``` 上述代码实现了 `CNNEncoder` 和 `RelationNetwork`,并定义了一个简单的训练函数。 --- #### TensorFlow 实现 Relation Network 下面是基于 TensorFlow 的 Relation Network 实现代码示例: ```python import tensorflow as tf from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dense, ReLU, Input, concatenate from tensorflow.keras.models import Model def create_cnn_encoder(): inputs = Input(shape=(84, 84, 3)) x = Conv2D(filters=64, kernel_size=3)(inputs) x = BatchNormalization()(x) x = ReLU()(x) x = MaxPooling2D(pool_size=2)(x) x = Conv2D(filters=64, kernel_size=3)(x) x = BatchNormalization()(x) x = ReLU()(x) x = MaxPooling2D(pool_size=2)(x) x = Conv2D(filters=64, kernel_size=3, padding="same")(x) x = BatchNormalization()(x) x = ReLU()(x) x = Conv2D(filters=64, kernel_size=3, padding="same")(x) x = BatchNormalization()(x) x = ReLU()(x) outputs = Flatten()(x) encoder_model = Model(inputs=inputs, outputs=outputs, name="cnn_encoder") return encoder_model def create_relation_network(input_dim, hidden_units): feature_a = Input(shape=input_dim) feature_b = Input(shape=input_dim) concatenated = concatenate([feature_a, feature_b]) x = Dense(units=hidden_units, activation="relu")(concatenated) x = Dense(units=1, activation="sigmoid")(x) relation_model = Model(inputs=[feature_a, feature_b], outputs=x, name="relation_network") return relation_model # Example usage encoder = create_cnn_encoder() relation_net = create_relation_network(input_dim=1600, hidden_units=8) ``` 此代码展示了如何使用 Keras API 构建 CNN 编码器和关系网络模型。 --- #### 方法概述 Relation Network 是一种通过学习嵌入空间中的相似性来解决小样本分类问题的有效方法[^2]。其核心思想在于设计一个深度非线性距离度量函数,该函数能够捕捉查询样本和支持集之间的关系。相比其他元学习方法,Relation Network 更加简洁高效,在多种任务上表现出色。 为了验证其实效性,可以采用 Omniglot 或 Mini-ImageNet 数据集进行实验,并按照 N-way-K-shot 设置划分支持集与查询集。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值