RNN --多分类的问题 名字对应国家

原作者刘老师是在一个文件中实现的。 代码来源:https://blog.youkuaiyun.com/Oscar6280868/article/details/108000451
我发布的该工程有俩个文件,其中有部分代码是重合的。因为函数之间的嵌套,没有去做优化。

在这里插入图片描述

tain_model.py

import gzip
import csv
import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader
from torchtext.data import Dataset

# 隐藏层的维度
HIDDEN_SIZE = 100
BATCH_SIZE = 256
# RNN的层数
N_LAYERS = 2
# 训练的轮数,暂定500轮
N_EPOCHS = 500
# 字符长度,也就是输入的维度
N_CHARS = 128
# 是否使用GPU
USE_GPU = True     #**此处使用GPU训练的速度会快一点  ,后面预测的时候又 使用了 CPU 因为张量的问题在代码中没有解决**


res =[]
is_train_set = True


class NameDataset(Dataset):
    def __init__(self, is_train_set=True):
    	# 指定训练集和测试集
        file_name = 'D:/1学习文档/nlp/数据集/刘二/names_train.csv.gz' if is_train_set else 'D:/1学习文档/nlp/数据集/刘二/names_test.csv.gz'
        with gzip.open(file_name, 'rt') as f:
            reader = csv.reader(f)
            rows = list(reader)
        # 人名
        self.names = [row[0] for row in rows]
        # 人名序列的长度
        self.length = len(self.names)
        # 人名所对应的国家
        self.countries = [row[1] for row in rows]
        # 所有国家的集合
        self.country_list = list(sorted(set(self.countries)))
        # 国家和index生成的字典
        self.country_dict = self.getCountryDict()
        # 国家的数量
        self.country_num = len(self.country_list)

    def __getitem__(self, index):
    	# 返回人名和所对应的国家名
        return self.names[index], self.country_dict[self.countries[index]]

    def __len__(self):
    	# 返回人名的长度
        return self.length

    def getCountryDict(self):
        country_dict = {
   
   }
        # 遍历数据建立国家的字典
        for idx, country_name in enumerate(self.country_list, 0):
            country_dict[country_name] = idx
        return country_dict

    def idx2country(self, index):
    	# 将index转换成国家名
        return self.country_list[index]

    def getCountryNum(self):
    	# 获取不同国家的数量
        return self.country_num

# 建立训练集的dataloader
train_set = NameDataset(is_train_set=True)
trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# 建立测试集的dataloader
test_set = NameDataset(is_train_set=False)
testloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)
# 获取国家数
N_COUNTRY = train_set.getCountryNum()

class RNNClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        # RNN隐藏层的维度
        self.hidden_size = hidden_size
        # 有多少层RNN
        self.n_layers = n_layers
        # 是否使用双向RNN
        self.n_directions = 2 if bidirectional else 1
        # 将序列进行embedding操作,维度为(seq_length(input_size), batch_size, hidden_size),input_size为字典的大小
        self.embedding = torch.nn.Embedding(input_size, hidden_size
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值