论文笔记:Prototypical Networks for Few-shot Learning

本文介绍Prototypical Networks,它通过神经网络将样本映射到同一空间,并利用欧几里得距离找到类别的原型。在训练时,目标是使测试样本接近其类别原型,远离其他类别。与Match Network相比,P-net使用欧几里得距离而非余弦距离,且在网络结构上有优化。补充内容探讨了Bregman divergence在选择距离度量中的作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Prototypical Networks for Few-shot Learning(用于小样本学习的原型网络)

论文中心思想:

通过神经网络学会一个“好的”映射,将各个样本投影到同一空间中,对于每种类型的样本提取他们的中心点(mean)作为原型(prototype)。使用欧几里得距离作为距离度量,训练使得测试样本到自己类别原型的距离越近越好,到其他类别原型的距离越远越好。测试时,通过对到每类原型的距离做sofmax获得测试样本类别。
原型网络原理示意

算法主要流程

总结分析:

本文提出的的Prototypical Networks(P-net)思想与match network(M-net)十分相似,但也有几个不同点:1.使用了不同的距离度量方式,M-net中是cosine度量距离,P-net中使用的是属于布雷格曼散度(详见论文)的欧几里得距离。2.二者在few-shot的场景下不同,在one-shot时等价(one-shot时取得的原型就是支持集中的样本)3.网络结构上,P-net相比M-net将编码层和分类层合一,参数更少,训练更加方便。

补充:

关于论文中Bregman divergence的理解,可参考知乎上一个大佬的回答:如何理解Bregman divergence? - 覃含章的回答 - 知乎
https://www.zhihu.com/question

### 关于Prototypical Networks for Few-Shot Learning论文 Prototypical Networks 是一种针对小样本学习(Few-Shot Learning)设计的方法,旨在通过构建类别原型来完成新类别的快速识别[^1]。该方法的核心思想是在嵌入空间中为每个类别定义一个原型向量,并利用这些原型向量计算输入样本与各个类别的相似度。 #### 论文下载方式 Prototypical Networks 的原始论文Prototypical Networks for Few-shot Learning》可以在以下链接找到并下载: - 原始论文地址:[https://arxiv.org/abs/1703.05175](https://arxiv.org/abs/1703.05175) 如果无法直接访问上述链接,可以尝试通过其他学术资源网站获取PDF版本,例如Google Scholar 或者 ResearchGate。此外,一些开源项目也提供了对该论文的解读和实现细节说明[^2][^3]。 #### 可视化的实现 对于可视化部分,作者展示了在嵌入空间中的 t-SNE 图像,其中类别原型用黑色表示,而错误分类的字符则用红色高亮显示[^4]。要实现类似的可视化效果,可以通过以下步骤: 1. **提取嵌入特征** 使用训练好的模型对测试数据进行前向传播,得到每张图像对应的嵌入向量。 2. **应用t-SNE降维** 利用 `sklearn.manifold.TSNE` 将高维嵌入向量投影到二维或三维空间以便绘制图形。 3. **绘图展示** 使用 Matplotlib 或 Seaborn 绘制散点图,在图中标记出不同类别的原型位置以及误分类的数据点。 以下是基于PyTorch实现的一个简单代码示例: ```python import torch from sklearn.manifold import TSNE import matplotlib.pyplot as plt def visualize_embeddings(embeddings, labels, prototypes): """ embeddings: 测试样本的嵌入向量 (NxD) labels: 对应的真实标签 (Nx1) prototypes: 类别原型向量 (KxD) """ all_data = torch.cat([embeddings, prototypes], dim=0).cpu().numpy() tsne = TSNE(n_components=2, random_state=42) embedded_data = tsne.fit_transform(all_data) # 提取样本点和原型点的位置 sample_points = embedded_data[:len(embeddings)] prototype_points = embedded_data[len(embeddings):] # 绘制散点图 plt.figure(figsize=(8, 6)) unique_labels = set(labels.cpu().numpy()) colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels))) for i, color in zip(unique_labels, colors): idx = np.where(labels == i)[0] plt.scatter(sample_points[idx, 0], sample_points[idx, 1], c=color, label=f'Class {i}') # 添加原型点标记 plt.scatter(prototype_points[:, 0], prototype_points[:, 1], marker='*', s=200, c='black', label='Prototypes') plt.legend() plt.title('t-SNE Visualization of Embedding Space') plt.show() ``` 此函数接受嵌入向量、真实标签以及类别原型作为输入参数,并生成相应的t-SNE可视化图表。 --- ###
评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值