SpanProto: A Two-stage Span-based Prototypical Network forFew-shot Named Entity Recognition

文章介绍了一种新的方法SpanProto,针对Few-shotner问题,通过span-based的两阶段学习,关注实体边界信息,并使用边际损失减少假阳性的错误。研究者提出原型学习和基于边界的分类策略,以提高在少量标注数据下的实体识别性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原文链接:

https://aclanthology.org/2022.emnlp-main.227.pdf

EMNLP 2022

介绍

        问题

        Few-shot ner之前的方法都是基于对token进行分类,忽略了实体边界的信息,同时大量的负样本(non-entity)也会影响模型的性能。

        IDEA 

         作者提出了一个span-based的two-stage原型网络(SpanProto)来解决Few-shot ner问题;在span提取阶段,将序列tag转化为一个全局的边界矩阵,使得模型能够注意到准确的边界信息。在分类阶段,对每个实体类别计算原型(prototype)embedding,并利用原型学习在语义空间中调整span的表征。另外,为了解决false positive问题,作者设计了一种margin-based损失,来增大false positive和所有原型之间的语义距离。

方法

        模型的整体结构如下所示:

         一个训练步的数据\varepsilon _{train}=(S_{train},Q_{train},T_{train}) \in D_{train},分别表示support set、query set和实体类型集合。对于每个样本\left ( X,M,y \right )\in S_{train}\cup Q_{train}X=\left \{ x_{i} \right \}_{i=1}^{L}表示有L个token的输入序列,M=\left \{ (s_{j}. e(j)) \right \}_{j=1}^{L{}'}表示句子X中的span set, y表示X中的实体集。

Span Extractor

        该模块用于生成所有的候选span,在每个训练步中,使用support 样本对span extractor进行训练。使用BERT来获得上下文embedding H,送入到span extractor中获取边界信息。

        具体的,对于每个span (xi,xj),先计算每个token xi的q和k,然后使用分数函数来计算(xi,xj)是实体span的边界概率,最后基于标记的span,对每一个support sentence生成一个全局边界矩阵。

        对于该模块,设计了一个基于span的交叉熵损失来学习边界信息(也就是使真实span的分数最大):

        对于训练集中的query 样本,获得span extractor生成的预测全局边界矩阵,将高于阈值的候选span作为预测结果。

Mention Classifier

         该模块主要包括两部分:原型学习和基于边界的分类;

        引入mention classifier 来为query样本中生成的每个span分配预定义好的实体类别,并利用原学习引导模型学习每个span的语义表征,使其能够迁移到新领域,同时提出基于边界的学习目标来解决false positive从而提高模型性能。

Prototypical Learning

        计算support数据中同一实体类别所有span表征的平均,作为该类别的原型向量ct,对于后续span的分类就计算该span与哪一个原型向量更近,即将span与原型向量之间的距离作为衡量span属于哪一类别的依据。

        给定训练集\varepsilon _{train}=(S_{train},Q_{train},T_{train}),通过求S_{train}中同一实体类别中所有span表征的平均来得到该类别的原型向量ct:

        其中Kt表示实体类型t的span数量,u_{j} = h_{sj} + h_{ej}表示span边界表征。

        该模块的损失为:

Margin-based Learning

        span extractor提取的span中不是ground truth的都视为false positive,如图中的“Italian”、“2011”,但这些span并不是非实体,只是在这个数据集中没有对应的标签,也就是说这些span的类别在该数据集中是不知道的,因此就要使这些span在语义空间上尽可能的远离现有的原型类别。

        其中u-表示这些span的边界表征。在基于边界的学习中,从所有的原型区域中抽取false positive span,从而获得噪声感知模型。

Training and Evaluation of SpanProto

        该框架的训练验证步骤如下所示:

        在每个训练步中,从Dtrain中随机抽取一个episode 数据,计算每个support 样本中span的loss(3-8行)训练span etractor,然后每个query 样本进行预测得到候选span,使用原型损失和基于边界损失来训练原型向量(10-15)。整个训练一共包括T步,先对span extractor进行T{}'的预训练,再对span extractor和分类器进行联合训练。

        进行推理时,计算span extractor提取的所有候选span 与每个原型的距离,选择最近的类别作为预测类别(在预测时,会删掉与原型之间距离大于r的span)。

实验

数据集介绍

         首先N-way K-shot表示N种标签,每个类别只有K个样本,作者使用了Few-NERD和CrossNER这两个数据集。

        Few-NERD 标注了8种粗粒度实体类型和66种细粒度实体类型(也就是大类别、小类别),包括两种设置:1)Intra:按粗粒度的实体进行分类,也就是训练集、测试集和验证集中包含不同的粗粒度实体类别;2)Inter:按细粒度进行划分,每个粗粒度类中,随机选择60%的细粒度实体作为训练集,随机挑选20%分别作为验证集和测试集。

        CrossNER包括4个不同的领域的ner数据集,随机选择两个数据集作为训练集,剩下的分别作为测试集和验证集。

对比实验

         作者在不同数据集设置下进行了对比实验,结果如下所示:

消融实验

         作者对模型的主要模块进行了消融实验,结果如下:

        其中w/o span表示使用传统的token-wise 原型网络框架;w/o Metion classifier表示直接利用k-means对span进行分类(这一块的影响最大);w/o Margin-based learning表示去掉了相关的损失函数。

其他实验

         作者对span extractor的效果与其他方法(ESD:利用滑动窗口来获取候选span;DecomMeta:利用BIOE标注方法来检测span)进行了对比实验:

        为了探究span extractor是怎么学习span边界,作者对数据集中的某个query sentence的预测全局边界矩阵进行了可视化,如下所示: 

         作者对预训练步数T和阈值\theta这两个超参数进行了实验,结果如下所示:

         对margin-based learning中的超参数r进行了实验,结果如下所示:

        

        对span 表征在语义空间进行了可视化,如下所示:

        可以看出还是能够比较好的分离出false positive样本。

         对FP-span(提取出边界错误的实体)和FP-type(表示边界正确,但类别错误)这两种错误进行统计分析,结果如下所示:

总结

        不是很懂Few-shot的处理方式,也不是很明白原型学习,而且文中说的margin-based
learning objective这个翻译为基于边界的学习目标,让我有点迷糊,但好像就是使false positive的表征离原型向量尽量远。

### 关于Prototypical Networks for Few-Shot Learning论文 Prototypical Networks 是一种针对小样本学习(Few-Shot Learning)设计的方法,旨在通过构建类别原型来完成新类别的快速识别[^1]。该方法的核心思想是在嵌入空间中为每个类别定义一个原型向量,并利用这些原型向量计算输入样本与各个类别的相似度。 #### 论文下载方式 Prototypical Networks 的原始论文《Prototypical Networks for Few-shot Learning》可以在以下链接找到并下载: - 原始论文地址:[https://arxiv.org/abs/1703.05175](https://arxiv.org/abs/1703.05175) 如果无法直接访问上述链接,可以尝试通过其他学术资源网站获取PDF版本,例如Google Scholar 或者 ResearchGate。此外,一些开源项目也提供了对该论文的解读和实现细节说明[^2][^3]。 #### 可视化的实现 对于可视化部分,作者展示了在嵌入空间中的 t-SNE 图像,其中类别原型用黑色表示,而错误分类的字符则用红色高亮显示[^4]。要实现类似的可视化效果,可以通过以下步骤: 1. **提取嵌入特征** 使用训练好的模型对测试数据进行前向传播,得到每张图像对应的嵌入向量。 2. **应用t-SNE降维** 利用 `sklearn.manifold.TSNE` 将高维嵌入向量投影到二维或三维空间以便绘制图形。 3. **绘图展示** 使用 Matplotlib 或 Seaborn 绘制散点图,在图中标记出不同类别的原型位置以及误分类的数据点。 以下是基于PyTorch实现的一个简单代码示例: ```python import torch from sklearn.manifold import TSNE import matplotlib.pyplot as plt def visualize_embeddings(embeddings, labels, prototypes): """ embeddings: 测试样本的嵌入向量 (NxD) labels: 对应的真实标签 (Nx1) prototypes: 类别原型向量 (KxD) """ all_data = torch.cat([embeddings, prototypes], dim=0).cpu().numpy() tsne = TSNE(n_components=2, random_state=42) embedded_data = tsne.fit_transform(all_data) # 提取样本点和原型点的位置 sample_points = embedded_data[:len(embeddings)] prototype_points = embedded_data[len(embeddings):] # 绘制散点图 plt.figure(figsize=(8, 6)) unique_labels = set(labels.cpu().numpy()) colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels))) for i, color in zip(unique_labels, colors): idx = np.where(labels == i)[0] plt.scatter(sample_points[idx, 0], sample_points[idx, 1], c=color, label=f'Class {i}') # 添加原型点标记 plt.scatter(prototype_points[:, 0], prototype_points[:, 1], marker='*', s=200, c='black', label='Prototypes') plt.legend() plt.title('t-SNE Visualization of Embedding Space') plt.show() ``` 此函数接受嵌入向量、真实标签以及类别原型作为输入参数,并生成相应的t-SNE可视化图表。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值