Experimental environment
- PyTorch 1.4.0
- conda 4.7.12
- Jupyter Notebook 6.0.1
- Python 3.7
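Dataset
The data comes from Douban movie reviews and consists of:
- Training set: about 20,000 Chinese movie reviews, half positive and half negative.
- Validation set: about 6,000 Chinese movie reviews, half positive and half negative.
- Test set: about 360 Chinese movie reviews, half positive and half negative.
- Pretrained word vectors: word2vec vectors trained on Chinese Wikipedia.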
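The loading code later in this post skips the first token of every line and treats the rest as words, which suggests each line is a label followed by whitespace-separated, pre-segmented words. Assuming that layout, a minimal sketch for checking the half-and-half class balance could look like the following. The names `label_distribution` and `sample_lines` are hypothetical, and the sample reviews are made up for illustration; in practice the lines would come from a file such as ./Dataset/train.txt.

```python
from collections import Counter


def label_distribution(lines):
    """Count how often each label (the first token of a line) occurs."""
    counts = Counter()
    for line in lines:
        tokens = line.strip().split()
        if tokens:  # skip blank lines
            counts[tokens[0]] += 1
    return counts


# Hypothetical sample in the assumed "label word word ..." layout.
sample_lines = [
    "0 这部 电影 很 好看",
    "1 剧情 太 无聊",
    "0 演员 演技 出色",
    "1 浪费 时间",
]

counts = label_distribution(sample_lines)
total = sum(counts.values())
for label, n in sorted(counts.items()):
    # Print each label, its count, and its share of all lines.
    print(label, n, n / total)
```

Running the same function over each split is a cheap way to confirm that neither class dominates before training.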
Training process
Data preparation
First, import the libraries needed for the experiment.
import gensim
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
from collections import Counter
from torch.utils.data import TensorDataset,DataLoader
from torch.optim.lr_scheduler import *
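- Build the word-to-id vocabulary and store it, in the form word: id. file: path where the word2id vocabulary is saved; save_to_path: when set, the vocabulary is written to file (in the code below it only acts as a switch).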
def build_word2id(file, save_to_path=None):
    """
    :param file: path where the word2id vocabulary is saved
    :param save_to_path: if truthy, write the vocabulary to `file`
    :return: word2id dict of the form {word: id}
    """
    # Id 0 is reserved for the padding token.
    word2id = {'_PAD_': 0}
    path = ['./Dataset/train.txt', './Dataset/validation.txt']
    for _path in path:
        with open(_path, encoding='utf-8') as f:
            for line in f:
                # sp[0] is skipped (presumably the label); the rest are words.
                sp = line.strip().split()
                for word in sp[1:]:
                    if word not in word2id:
                        word2id[word] = len(word2id)
    if save_to_path:
        # One "word<TAB>id" pair per line.
        with open(file, 'w', encoding='utf-8') as f:
            for w in word2id:
                f.write(w + '\t')
                f.write(str(word2id[w]))
                f.write('\n')
    return word2id
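To see the id assignment in isolation, here is a small sketch that applies the same rule to in-memory lines instead of the hard-coded dataset files. `build_vocab` is a hypothetical helper written for this illustration, not part of the original code, and it assumes the first token of each line is the label.

```python
def build_vocab(lines, pad_token="_PAD_"):
    """Assign consecutive ids to words in first-seen order; id 0 is padding."""
    word2id = {pad_token: 0}
    for line in lines:
        # The first token is assumed to be the label, so it is skipped.
        for word in line.strip().split()[1:]:
            if word not in word2id:
                word2id[word] = len(word2id)
    return word2id


lines = ["0 电影 很 好看", "1 电影 太 无聊"]
vocab = build_vocab(lines)

# Encode a new, already segmented sentence with the vocabulary.
ids = [vocab[w] for w in "电影 很 无聊".split()]
print(vocab)
print(ids)
```

Because each new word receives `len(word2id)` as its id, the ids are dense and start right after the padding id, which is what lets them index rows of an embedding matrix directly.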
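- Build the word vectors for the words in the training corpus from the pretrained word2vec model. fname: the pretrained word2vec file; word2id: the vocabulary of the corpus; save_to_path: where to save the vectors for the corpus vocabulary locally. Returns the word2vec vectors for the corpus vocabulary, indexed by id.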
def build_word2vec(fname, word2id, save_to_path=None):
    """
    :param fname: pretrained word2vec file.
    :param word2id: vocabulary of the corpus.
    :param save_to_path: where to save the vectors of the corpus vocabulary locally
    :return: array of word2vec vectors; row i is the vector for id i.
    """
    n_words = max(word2id.values()) + 1
    model = gensim.models.KeyedVectors.load_word2vec_format(fname, binary=True)
    # Start from random vectors in [-1, 1); words missing from the
    # pretrained model keep this random initialization.
    word_vecs = np.random.uniform(-1., 1., [n_words, model.vector_size])
    for word in word2id.keys():
        try:
            word_vecs[word2id[word]] = model[word]
        except KeyError:
            pass
    if save_to_path:
        # One space-separated vector per line, in id order.
        with open(save_to_path, 'w', encoding='utf-8') as f:
            for vec in word_vecs:
                vec = [str(w) for w in vec]
                f.write(' '.join(vec))
                f.write('\n')
    return word_vecs
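The fallback logic above (pretrained vector if the word is known, random vector otherwise) can be tried without the pretrained file. The sketch below is a plain-Python stand-in under stated assumptions: `build_embedding_rows` is a hypothetical helper, an ordinary dict replaces the gensim model, and nested lists replace the NumPy array.

```python
import random


def build_embedding_rows(word2id, pretrained, dim, seed=0):
    """Return one vector per id: the pretrained one if available, else random."""
    rng = random.Random(seed)
    n_words = max(word2id.values()) + 1
    # Start from uniform random vectors in [-1, 1].
    rows = [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(n_words)]
    for word, idx in word2id.items():
        vec = pretrained.get(word)
        if vec is not None:
            # Overwrite the random row with the pretrained vector.
            rows[idx] = list(vec)
    return rows


word2id = {"_PAD_": 0, "电影": 1, "好看": 2}
pretrained = {"电影": [0.1, 0.2, 0.3]}  # hypothetical 3-dimensional vectors
rows = build_embedding_rows(word2id, pretrained, dim=3)
print(rows[1])  # the pretrained vector for "电影"
```

Note that the padding row is also left random under this scheme unless the pretrained model happens to contain the padding token.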
- Mapping between class labels and ids, pos:0, neg:1. classes: the class labels, by default 0:pos, 1:neg. Returns a {label: id} dictionary.
def cat_to_id(classes=None):
    """
    :param classes: class labels; defaults to 0:pos, 1:neg
    :return: {class label: id}
    """
    if not classes:
        classes = ['0',