Prototypical Networks for Few-shot Learning.(用于少样本分类的原型网络)

**

Prototypical Networks for Few-shot Learning.(用于少样本分类的原型网络)

**

摘要

文章提出了一种用于少样本分类的原型网络,其中分类器必须可以推广(泛化)到在训练集里面没有见过的新的类别,并且每个新的类别只有很少一部分样本。原型网络学习一个度量空间,执行分类只需要简单的计算到每个类的原型表示的距离。与其他方法的主要一点不同是原型网络反映了一种在数据集不足情况下的归纳偏差。最后一点是原型网络可以用于zero-shot learning(零样本学习)。

简单的设计选择带来了实质性的改进

本质思想

使用一个神经网络学习一个 embedding 将输入映射到一个映射空间里面,而每个类的原型就是简单的这个类的 support set 所有样本(support set 的定义参见后文)的均值。分类的执行:将需要分类的样本用学习到的映射函数映射到嵌入空间里面,然后离他最近的原型就是我们的预测类。

介绍

  1. few-shot learning 问题的本质之一是样本很少,一般每个类的样本少于20个。另外一个本质是这些类是全新的类。few-shot learning 问题的分类器必须调整以适应新的类,并且新的类中每个类的样本很少。传统的方法是在新的数据上重新训练新的模型,但是由于样本太少,往往会导致过拟合 (overfitting)。但是**我们人类是可以做到 few-shot learning **的,因此迈向普世的人工智能,few-shot leanring 问题是一个重要问题。
  2. 文章假设:由于样本有限,我们的分类器会有一个归纳偏差 (inductive bias)。 原型网络基于这样的想法:存在一种 embedding,使得其中每个类的点(可能多个)聚集在每个类的单个原型表示周围。
  3. 方法概述:使用一个神经网络学习一个 embedding 将输入映射到一个映射空间里面,而每个类的原型就是简单的这个类的 support set 所有样本(support set 的定义参见后文)的均值。分类的执行:将需要分类的样本用学习到的映射函数映射到嵌入空间里面,然后离他最近的原型就是我们的预测类。zero-shot learning 类似,只不过每个类没有标记 (label),而是一个关于类的高级描述 (high-level description of the class)。

方法

  1. 符号定义:support set 是我们具有的 N N N 个标记的样本 S = { ( x 1 , y 1 ) , … , ( x N , y N ) } S=\left\{\left(\mathbf{x}_{1}, y_{1}\right), \ldots,\left(\mathbf{x}_{N}, y_{N}\right)\right\} S={ (x1,y1),,(xN,yN)},其中 x i ∈ R D \mathbf{x}_{i} \in \mathbb{R}^{D} xiRD 是一个 D D D维的向量, y i ∈ { 1 , … , K } y_{i} \in\{1, \ldots, K\} yi{ 1,,K} 使对应的类, S k S_{k} Sk 表示类别为 k k k 的样本的集合。
  2. 模型:首先一个映射函数 (embedding): f ϕ : R D → R M f_{\phi} : \mathbb{R}^{D} \rightarrow \mathbb{R}^{M} fϕ:RDRM 参数为 ϕ \phi ϕ,将输入空间映射到嵌入空间 (嵌入空间的维度为 M M M),然后计算每个类的原型: c k = 1 ∣ S k ∣ ∑ ( x i , y i ) ∈ S k f ϕ ( x i ) \mathbf{c}_{k}=\frac{1}{\left|S_{k}\right|} \sum_{\left(\mathbf{x}_{i}, y_{i}\right) \in S_{k}} f_{\phi}\left(\mathbf{x}_{i}\right) ck=Sk1(xi,yi)<
### 关于Prototypical Networks for Few-Shot Learning论文 Prototypical Networks 是一种针对小样本学习(Few-Shot Learning)设计的方法,旨在通过构建类别原型来完成新类别的快速识别[^1]。该方法的核心思想是在嵌入空间中为每个类别定义一个原型向量,并利用这些原型向量计算输入样本与各个类别的相似度。 #### 论文下载方式 Prototypical Networks 的原始论文《Prototypical Networks for Few-shot Learning》可以在以下链接找到并下载: - 原始论文地址:[https://arxiv.org/abs/1703.05175](https://arxiv.org/abs/1703.05175) 如果无法直接访问上述链接,可以尝试通过其他学术资源网站获取PDF版本,例如Google Scholar 或者 ResearchGate。此外,一些开源项目也提供了对该论文的解读和实现细节说明[^2][^3]。 #### 可视化的实现 对于可视化部分,作者展示了在嵌入空间中的 t-SNE 图像,其中类别原型用黑色表示,而错误分类的字符则用红色高亮显示[^4]。要实现类似的可视化效果,可以通过以下步骤: 1. **提取嵌入特征** 使用训练好的模型对测试数据进行前向传播,得到每张图像对应的嵌入向量。 2. **应用t-SNE降维** 利用 `sklearn.manifold.TSNE` 将高维嵌入向量投影到二维或三维空间以便绘制图形。 3. **绘图展示** 使用 Matplotlib 或 Seaborn 绘制散点图,在图中标记出不同类别的原型位置以及误分类的数据点。 以下是基于PyTorch实现的一个简单代码示例: ```python import torch from sklearn.manifold import TSNE import matplotlib.pyplot as plt def visualize_embeddings(embeddings, labels, prototypes): """ embeddings: 测试样本的嵌入向量 (NxD) labels: 对应的真实标签 (Nx1) prototypes: 类别原型向量 (KxD) """ all_data = torch.cat([embeddings, prototypes], dim=0).cpu().numpy() tsne = TSNE(n_components=2, random_state=42) embedded_data = tsne.fit_transform(all_data) # 提取样本点和原型点的位置 sample_points = embedded_data[:len(embeddings)] prototype_points = embedded_data[len(embeddings):] # 绘制散点图 plt.figure(figsize=(8, 6)) unique_labels = set(labels.cpu().numpy()) colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels))) for i, color in zip(unique_labels, colors): idx = np.where(labels == i)[0] plt.scatter(sample_points[idx, 0], sample_points[idx, 1], c=color, label=f'Class {i}') # 添加原型点标记 plt.scatter(prototype_points[:, 0], prototype_points[:, 1], marker='*', s=200, c='black', label='Prototypes') plt.legend() plt.title('t-SNE Visualization of Embedding Space') plt.show() ``` 此函数接受嵌入向量、真实标签以及类别原型作为输入参数,并生成相应的t-SNE可视化图表。 --- ###
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值