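NLP-Task3: Attention-Based Text Matching

This post introduces ESIM, an attention-based text matching model, and evaluates it on Stanford's SNLI dataset. The model extracts features with a bidirectional LSTM, then applies local inference modeling and the subsequent steps to complete the text matching task.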


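Given two input sentences, determine the relationship between them, using a bidirectional attention mechanism.

Dataset: https://nlp.stanford.edu/projects/snli/
Reference paper: Enhanced LSTM for Natural Language Inference
Paper walkthrough: Note

  • Key concepts
    • Attention mechanism
    • token2token attention
This post is based on: NLP-Beginner Task 3: Attention-Based Text Matching + pytorch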

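1. Task Overview

This task uses the ESIM model proposed in the paper to perform text matching.

1.1 Dataset

The training set contains over 550,000 items. Each sentence pair carries one of four labels: entailment, contradiction, neutral, or unknown (-).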
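As a minimal sketch of how one record maps to a (premise, hypothesis, label) triple: the column positions (label in column 0, sentences in columns 5 and 6) and the label ordering mirror the code in section 4.2. The helper name `parse_line`, the tab delimiter, and the two toy records are assumptions made for illustration and are not taken from SNLI.

```python
from collections import Counter

# Label-to-index mapping, in the same order as type_dict in section 4.2.
LABEL_TO_ID = {'-': 0, 'contradiction': 1, 'entailment': 2, 'neutral': 3}


def parse_line(line, sep='\t'):
    """Return (premise, hypothesis, label_id) from one delimited record.

    Assumes the label is in column 0 and the two sentences are in
    columns 5 and 6; `sep` is an assumption and should match your file.
    """
    fields = line.rstrip('\n').split(sep)
    return fields[5], fields[6], LABEL_TO_ID[fields[0]]


# Two made-up records with placeholder values in the unused columns.
toy_lines = [
    'entailment\tp1\tp2\tp3\tp4\tA man plays a guitar.\tA person makes music.',
    '-\tp1\tp2\tp3\tp4\tA dog runs.\tThe dog is asleep.',
]
records = [parse_line(line) for line in toy_lines]

# Count how many records fall under each label id.
label_counts = Counter(label for _, _, label in records)
```

Counting labels this way is a quick check on how many pairs carry the unknown label before training.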

2. Feature Extraction: Word Embedding

3. Neural Network

The network follows the paper Enhanced LSTM for Natural Language Inference.

3.1 Input Encoding

3.2 Local Inference Modeling
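The local inference step in the ESIM paper soft-aligns the two encoded sentences: the score between tokens is the dot product e_ij = a_i · b_j, each premise token is summarized by a softmax-weighted sum over hypothesis tokens (and vice versa), and the result is enhanced as [x; x~; x - x~; x * x~]. The following is a framework-free sketch of that idea using plain Python lists. All function names are hypothetical, and masking of padded positions is left out for brevity.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def weighted_sum(weights, vectors):
    """Sum of vectors, each scaled by its weight."""
    dim = len(vectors[0])
    return [sum(w * vec[k] for w, vec in zip(weights, vectors)) for k in range(dim)]


def soft_align(a, b):
    """a: premise token vectors, b: hypothesis token vectors (lists of lists)."""
    # e[i][j] = a_i . b_j, the token-to-token attention score
    e = [[dot(ai, bj) for bj in b] for ai in a]
    # Each premise token attends over all hypothesis tokens (softmax over j)
    a_tilde = [weighted_sum(softmax(row), b) for row in e]
    # Each hypothesis token attends over all premise tokens (softmax over i)
    cols = [[e[i][j] for i in range(len(a))] for j in range(len(b))]
    b_tilde = [weighted_sum(softmax(col), a) for col in cols]
    return a_tilde, b_tilde


def enhance(x, x_tilde):
    """Concatenate [x; x~; x - x~; x * x~] for every token."""
    return [
        xi + ti + [p - q for p, q in zip(xi, ti)] + [p * q for p, q in zip(xi, ti)]
        for xi, ti in zip(x, x_tilde)
    ]


# Toy inputs: 2 premise tokens and 3 hypothesis tokens of dimension 2
a = [[1.0, 0.0], [0.0, 1.0]]
b = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
a_tilde, b_tilde = soft_align(a, b)
m_a = enhance(a, a_tilde)
```

The enhanced vectors are four times as wide as the encoder output, which is why the composition layer that follows usually projects them back down first.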

3.3 Inference Composition

3.4 Output

3.5 Training the Network

The network is trained with the cross-entropy loss.
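For reference, the cross-entropy of one example with class scores (logits) z and true class y is -log softmax(z)_y, which equals log-sum-exp(z) - z_y. The sketch below computes it in plain Python to make the formula concrete. The function names are hypothetical. In PyTorch the same quantity, averaged over a batch, is what `torch.nn.CrossEntropyLoss` computes from logits and class indices.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one example: log-sum-exp(logits) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def batch_loss(batch_logits, targets):
    """Mean cross-entropy over a batch."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# With four equally scored classes the loss is log(4), whatever the target.
uniform_loss = cross_entropy([0.0, 0.0, 0.0, 0.0], 2)

# A confident correct prediction gives a loss close to zero.
confident_loss = cross_entropy([0.0, 0.0, 10.0, 0.0], 2)
```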

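4. Code

4.1 Experimental Setup

  • Number of samples: about 550,000
  • Training set to test set ratio: 7:3
  • Model: ESIM
  • Initialization: random initialization, or initialization from pretrained GloVe vectors
  • Learning rate: $10^{-3}$
  • Batch size: 1000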
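The two initialization schemes in the list above can be sketched as follows: every row starts random, and rows whose word appears in a GloVe text file (one word followed by its space-separated vector components per line) are overwritten with the pretrained vector. The function name `build_embedding`, the standard deviation of 0.1, the lowercase lookup, and the extra zero row for padding are assumptions chosen for illustration, not the settings used in this post.

```python
import random


def build_embedding(vocab, glove_lines, dim, seed=0):
    """Return a (len(vocab) + 1) x dim matrix as a list of lists.

    vocab maps word -> row index; the final row is reserved for padding.
    """
    rng = random.Random(seed)

    # Parse "word v1 v2 ... v_dim" lines into a lookup table.
    pretrained = {}
    for line in glove_lines:
        parts = line.split()
        pretrained[parts[0]] = [float(x) for x in parts[1:]]

    # Random initialization for every word (assumed std of 0.1).
    matrix = [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(len(vocab) + 1)]
    matrix[len(vocab)] = [0.0] * dim  # padding row

    # Overwrite with pretrained vectors where available (assumes lowercase keys).
    for word, idx in vocab.items():
        vec = pretrained.get(word.lower())
        if vec is not None and len(vec) == dim:
            matrix[idx] = vec
    return matrix


# Toy vocabulary and two made-up GloVe-style lines
vocab = {'CAT': 0, 'DOG': 1}
glove_lines = ['cat 0.1 0.2', 'bird 0.3 0.4']
embedding = build_embedding(vocab, glove_lines, dim=2)
```

Passing an empty `glove_lines` list gives the purely random initialization.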

4.2 Implementation

#feature
import random
import re
import torch
from torch.utils.data import Dataset, DataLoader

def data_split(data, test_rate=0.3):
    """Randomly split data into a training set and a test set."""
    train = list()
    test = list()
    for datum in data:
        # Each item lands in the test set with probability test_rate
        if random.random() > test_rate:
            train.append(datum)
        else:
            test.append(datum)
    return train, test


class Random_embedding():
    def __init__(self, data, test_rate=0.3):
        self.dict_words = dict()
        _data = [item.split(',') for item in data]  # assumes comma-separated fields; adjust the delimiter to match your file
        self.data = [[item[5], item[6], item[0]] for item in _data]
        self.len_words = 0
        self.train, self.test = data_split(self.data, test_rate=test_rate)
        self.type_dict = {'-': 0, 'contradiction': 1, 'entailment': 2, 'neutral': 3}
        self.train_y = [self.type_dict[term[2]] for term in self.train]  # Relation in training set
        self.test_y = [self.type_dict[term[2]] for term in self.test]  # Relation in test set
        self.train_s1_matrix = list()
        self.test_s1_matrix = list()
        self.train_s2_matrix = list()
        self.test_s2_matrix = list()
        self.longest = 50

    def get_words(self):
        pattern = '[A-Za-z|\']+'
        for term in self.data:
            for i in range(2):
                s = term[i]
                s = s.upper()
                words = re.findall(pattern, s)
                for word in words:  # Process every word
                    if word not in self.dict_words:
                        self.dict_words[word] = len(self.dict_words)
        self.len_words = len(self.dict_words)

    def get_id(self):
        pattern = '[A-Za-z|\']+'
        for term in self.train:
            s = term[0]
            s = s.upper()
            words = re.findall(pattern, s)
            item = [self.dict_words[word] for word in words]
            item+=[self.len_words for _ in range(self.longest-len(item))]
            self.longest = max(self.longest, len(item))
            self.train_s1_matrix.append(item)
            s = term[1]
            s = s.upper()
            words = re.findall(pattern, s)
            item = [self.dict_words[word] for word in words]
            item += [self.len_words for _ in range(self.longest - len(item))]
            self.longest = max(self.longest, len(item))
            self.train_s2_matrix.append(item)
        for term in self.test:
            s = term[0]
            s = s.upper()
            words = re.findall(pattern, s)
            item = [self.dict_words[word] for word in words]
            item += [self.len_words for _ in range(self.longest - len(item))]
            self.longest = max(self.longest, len(item))
            self.test_s1_matrix.append(item)
            s = term[1]
            s = s.upper()
            words = re.findall(pattern, s)
            item = [self.dict_words[word] for word in words]
            item += [self.len_words for _ in range(self.longest - len(item))]
            self.longest = max(self.longest, len(item))
            self.test_s2_matrix.append(item)