原作者刘老师是在一个文件中实现的。 代码来源:https://blog.youkuaiyun.com/Oscar6280868/article/details/108000451
我发布的该工程有俩个文件,其中有部分代码是重合的。因为函数之间的嵌套,没有去做优化。
tain_model.py
import gzip
import csv
import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader
from torchtext.data import Dataset
# 隐藏层的维度
HIDDEN_SIZE = 100
BATCH_SIZE = 256
# RNN的层数
N_LAYERS = 2
# 训练的轮数,暂定500轮
N_EPOCHS = 500
# 字符长度,也就是输入的维度
N_CHARS = 128
# 是否使用GPU
USE_GPU = True #**此处使用GPU训练的速度会快一点 ,后面预测的时候又 使用了 CPU 因为张量的问题在代码中没有解决**
res =[]
is_train_set = True
class NameDataset(Dataset):
def __init__(self, is_train_set=True):
# 指定训练集和测试集
file_name = 'D:/1学习文档/nlp/数据集/刘二/names_train.csv.gz' if is_train_set else 'D:/1学习文档/nlp/数据集/刘二/names_test.csv.gz'
with gzip.open(file_name, 'rt') as f:
reader = csv.reader(f)
rows = list(reader)
# 人名
self.names = [row[0] for row in rows]
# 人名序列的长度
self.length = len(self.names)
# 人名所对应的国家
self.countries = [row[1] for row in rows]
# 所有国家的集合
self.country_list = list(sorted(set(self.countries)))
# 国家和index生成的字典
self.country_dict = self.getCountryDict()
# 国家的数量
self.country_num = len(self.country_list)
def __getitem__(self, index):
# 返回人名和所对应的国家名
return self.names[index], self.country_dict[self.countries[index]]
def __len__(self):
# 返回人名的长度
return self.length
def getCountryDict(self):
country_dict = {
}
# 遍历数据建立国家的字典
for idx, country_name in enumerate(self.country_list, 0):
country_dict[country_name] = idx
return country_dict
def idx2country(self, index):
# 将index转换成国家名
return self.country_list[index]
def getCountryNum(self):
# 获取不同国家的数量
return self.country_num
# 建立训练集的dataloader
train_set = NameDataset(is_train_set=True)
trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# 建立测试集的dataloader
test_set = NameDataset(is_train_set=False)
testloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)
# 获取国家数
N_COUNTRY = train_set.getCountryNum()
class RNNClassifier(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
super(RNNClassifier, self).__init__()
# RNN隐藏层的维度
self.hidden_size = hidden_size
# 有多少层RNN
self.n_layers = n_layers
# 是否使用双向RNN
self.n_directions = 2 if bidirectional else 1
# 将序列进行embedding操作,维度为(seq_length(input_size), batch_size, hidden_size),input_size为字典的大小
self.embedding = torch.nn.Embedding(input_size, hidden_size