元学习2之原型网络prototypical-networks for few-shot learning回顾

原型网络是一种小样本学习方法,通过计算类别的平均特征(类中心)来表示类别,提高模型对噪声的鲁棒性。算法涉及N-wayK-shot任务,使用度量如欧式距离来决定样本归属。与MatchNetworks的主要区别在于度量方式和网络结构,原型网络在one-shot和few-shot场景下表现不同。实验表明,使用余弦相似度通常能取得更好的效果。

1.论文和代码

2.简介

小样本学习不仅仅训练和测试集的样本没有交集,类别也是没有交集的
论文一共做了两个任务:1.小样本;2. 零样本。

在这里插入图片描述

如图左边:
三种颜色代表三个类别(3-way), c 1 , c 2 , c 3 c_1,c_2,c_3 c1,c2,c3分别是三个类别的中心,用类中心表示类别的好处就是某个类别中某些数据存在一些噪声,用类中心来表示这个特征,比较robust,类中心对抗噪声的能力比单个样本生存的这个特征要强很多。
类中心(prototypes)用各类别下所有样本特征的平均值来计算的: v c = 1 ∣ S c ∣ ∑ ( x i , y i ) ∈ S c f θ ( x i ) v_c=\frac{1}{|S_c|}\sum_{(x_i,y_i)\in S_c} f_{\theta}(x_i) vc=Sc1(xi,yi)Scfθ(xi)
X X X为是Query set,要确定它的类别,判断它的类别需要它分别计算与 c 1 , c 2 , c 3 c_1,c_2,c_3 c1,c2,c3之间的距离,哪一个距离越小,与哪个类别的相似度就越大,就归为哪一类。

基于度量的元学习,度量类别与类别之间的距离的一些指标(欧式距离或者余弦距离等)。

3.算法流程

在这里插入图片描述

3.1 名词解释及符号定义

  • episode:表示一个N-way K-shot任务,N为类别数,K为每个类别的数量。 N C N_C NC为每个任务中类别的数量, N S + N Q N_S+N_Q NS+NQ每个类别样本的数量。这里就是 N C N_C NC-way N S N_S NS-shot N Q N_Q N
Prototypical networks for few-shot learning 是少样本学习领域的重要研究,其核心论文《Prototypical Networks for Few - Shot Learning》提出了原型网络用于解决少样本学习问题。 原型网络的基本原理是在嵌入空间中为每个类别学习一个“原型”(prototype),这个原型是该类别所有样本嵌入向量的均值。在测试阶段,通过计算测试样本与各个类别原型之间的距离来进行分类,将测试样本分类到距离最近的原型所代表的类别中。 以下是一个简单的 PyTorch 代码示例,用于说明原型网络的基本流程: ```python import torch import torch.nn as nn import torch.optim as optim # 定义一个简单的嵌入网络 class EmbeddingNetwork(nn.Module): def __init__(self): super(EmbeddingNetwork, self).__init__() self.fc1 = nn.Linear(10, 20) self.relu = nn.ReLU() self.fc2 = nn.Linear(20, 10) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x # 计算原型 def compute_prototypes(support_set, labels, num_classes): prototypes = [] for c in range(num_classes): class_indices = (labels == c).nonzero(as_tuple=True)[0] class_samples = support_set[class_indices] prototype = torch.mean(class_samples, dim=0) prototypes.append(prototype) return torch.stack(prototypes) # 分类函数 def classify(query_set, prototypes): distances = torch.cdist(query_set, prototypes) predictions = torch.argmin(distances, dim=1) return predictions # 初始化网络 embedding_net = EmbeddingNetwork() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(embedding_net.parameters(), lr=0.001) # 模拟数据 support_set = torch.randn(20, 10) # 20 个样本,每个样本 10 维 support_labels = torch.randint(0, 5, (20,)) # 5 个类别 query_set = torch.randn(10, 10) # 10 个查询样本 # 训练过程 for epoch in range(100): optimizer.zero_grad() support_embeddings = embedding_net(support_set) prototypes = compute_prototypes(support_embeddings, support_labels, 5) query_embeddings = embedding_net(query_set) predictions = classify(query_embeddings, prototypes) # 这里假设我们有查询样本的真实标签 query_labels = torch.randint(0, 5, (10,)) loss = criterion(predictions.float(), query_labels.float()) loss.backward() optimizer.step() ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值