Generating Names with a Character-Level RNN

This post walks through a PyTorch character-level name generator. The model is trained on datasets of names from different countries and learns the patterns by which names are formed. The post covers the data preprocessing pipeline, the training procedure, the sampling (prediction) method, and a discussion of how the model works.


Contents:

1. Preparing the Data

2. Creating the Network

3. Training

3.1 Preparing for Training

3.2 Training the Network

3.3 Plotting the Losses

4. Evaluating the Network


3.1 Preparing for Training

EOS is the marker letter that ends a name.

3.2 Training the Network

import string

import torch
import torch.nn as nn

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1 # Plus EOS marker

criterion = nn.NLLLoss()

learning_rate = 0.0005

def train(category_tensor, input_line_tensor, target_line_tensor):
    target_line_tensor.unsqueeze_(-1)  # (seq_len,) -> (seq_len, 1) so each step's target has shape (1,)
    hidden = rnn.initHidden()

    rnn.zero_grad()

    loss = 0

    for i in range(input_line_tensor.size(0)):
        output, hidden = rnn(category_tensor, input_line_tensor[i], hidden)
        loss += criterion(output, target_line_tensor[i])

    loss.backward()

    # Manual SGD step; the old positional add_(-lr, grad) signature is deprecated
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item() / input_line_tensor.size(0)

rnn = RNN(n_letters, 128, n_letters)

n_iters = 100000
print_every = 5000
plot_every = 500
all_losses = []
total_loss = 0 # Reset every plot_every iters

start = time.time()

for iter in range(1, n_iters + 1):
    output, loss = train(*randomTrainingExample())
    total_loss += loss

    if iter % print_every == 0:
        print('%s (%d %d%%) %.4f' % (timeSince(start), iter, iter / n_iters * 100, loss))

    if iter % plot_every == 0:
        all_losses.append(total_loss / plot_every)
        total_loss = 0
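The loop above calls a `timeSince` helper that this excerpt never defines. In the original tutorial it is (approximately) the following small elapsed-time formatter:

```python
import math
import time

def timeSince(since):
    # Format the time elapsed since `since` as "Xm Ys"
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
```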
When the network is initialized, the input vocabulary is the 52 English letters plus 6 punctuation characters plus one EOS marker, so n_letters = 59.
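The arithmetic is easy to check directly:

```python
import string

# 52 ASCII letters + 6 punctuation characters, and one extra slot for EOS
all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1

print(len(string.ascii_letters))  # 52
print(len(" .,;'-"))              # 6
print(n_letters)                  # 59
```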

The first step is to draw a training example at random. The helper functions are defined as follows:

# One-hot vector for category
def categoryTensor(category):
    li = all_categories.index(category)
    tensor = torch.zeros(1, n_categories)
    tensor[0][li] = 1
    return tensor

# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor

# LongTensor of second letter to end (EOS) for target
def targetTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1) # EOS
    return torch.LongTensor(letter_indexes)

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    category, line = randomTrainingPair()
    category_tensor = categoryTensor(category)
    input_line_tensor = inputTensor(line)
    target_line_tensor = targetTensor(line)
    return category_tensor, input_line_tensor, target_line_tensor
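As a quick, self-contained sanity check, the three helpers can be exercised on 'Vescovi' directly (the two-entry `all_categories` list below is a toy stand-in, not the tutorial's full 18-language list):

```python
import string

import torch

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1
all_categories = ['Spanish', 'Italian']  # toy subset, for illustration only
n_categories = len(all_categories)

def categoryTensor(category):
    li = all_categories.index(category)
    tensor = torch.zeros(1, n_categories)
    tensor[0][li] = 1
    return tensor

def inputTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li in range(len(line)):
        tensor[li][0][all_letters.find(line[li])] = 1
    return tensor

def targetTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1)  # EOS
    return torch.LongTensor(letter_indexes)

print(categoryTensor('Italian').shape)   # torch.Size([1, 2])
print(inputTensor('Vescovi').shape)      # torch.Size([7, 1, 59])
print(targetTensor('Vescovi').tolist())  # [4, 18, 2, 14, 21, 8, 58]
```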

ipdb> category
u'Italian'
ipdb> line
u'Vescovi'

ipdb> category
u'Greek'
ipdb> line
u'Christodoulou'  # 13 letters
ipdb> all_categories
[u'Spanish', u'Italian', u'Korean', u'French', u'Japanese', u'Polish', u'Scottish', u'English', u'Portuguese', u'Vietnamese', u'German', u'Dutch', u'Chinese', u'Czech', u'Arabic', u'Irish', u'Greek', u'Russian']
ipdb> category_lines['Italian']
[u'Abandonato', u'Abatangelo', u'Abatantuono',..., u'Zunino']
ipdb> randomChoice([u'Abandonato', u'Abatangelo', u'Abatantuono'])
u'Abandonato'
ipdb> category_tensor
Variable containing:

Columns 0 to 12 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 13 to 17 
    0     0     0     1     0
[torch.FloatTensor of size 1x18]
ipdb> input_line_tensor
Variable containing:
(0 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   1   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(1 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   1   0   0   0   0   0   0   0   0   0   0   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(2 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   1   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(3 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   1   0   0   0   0   0   0   0   0   0   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(4 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   1

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(5 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 19 to 37 
    1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(6 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   1   0   0   0   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(7 ,.,.) = 

Columns 0 to 18 
    0   0   0   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(8 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   1   0   0   0   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(9 ,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 19 to 37 
    0   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(10,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   1   0   0   0   0   0   0   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(11,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   1   0   0   0   0

Columns 19 to 37 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0

(12,.,.) = 

Columns 0 to 18 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 19 to 37 
    0   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 38 to 56 
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0

Columns 57 to 58 
    0   0
[torch.FloatTensor of size 13x1x59]
ipdb> target_line_tensor
Variable containing:
  7
 17
  8
 18
 19
 14
  3
 14
 20
 11
 14
 20
 58
[torch.LongTensor of size 13]

Note the difference between the following two expressions:

ipdb> [all_letters.find(line[li]) for li in range(1, len(line))]
[7, 17, 8, 18, 19, 14, 3, 14, 20, 11, 14, 20]

ipdb> [all_letters.find(line[li]) for li in range(0, len(line))]
[28, 7, 17, 8, 18, 19, 14, 3, 14, 20, 11, 14, 20]

Next, the final EOS index is appended:

ipdb> letter_indexes.append(n_letters - 1)

ipdb> letter_indexes
[7, 17, 8, 18, 19, 14, 3, 14, 20, 11, 14, 20, 58]
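The same index arithmetic can be reproduced in plain Python, without any of the ipdb session state:

```python
import string

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # 59; the last index is reserved for EOS

line = 'Christodoulou'
# Indices of the second letter onward, then EOS, exactly as in targetTensor
letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
letter_indexes.append(n_letters - 1)  # EOS
print(letter_indexes)  # [7, 17, 8, 18, 19, 14, 3, 14, 20, 11, 14, 20, 58]
```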

The tensors passed to train have sizes 1x18, 13x1x59, and 13, respectively.

The rnn itself then runs 13 forward steps; at each step it receives the 1x18 category tensor, one 1x59 letter tensor, and the hidden state (a FloatTensor of size 1x128).
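This excerpt never shows the RNN class itself. The original tutorial's version, which concatenates the category, letter, and hidden tensors at every step, looks roughly like the sketch below (reproduced from the tutorial's structure; `n_categories = 18` matches the dataset used here):

```python
import torch
import torch.nn as nn

n_categories = 18  # number of languages in the dataset

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        # Each step sees category (1x18) + letter (1x59) + hidden (1x128)
        self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, category, input, hidden):
        input_combined = torch.cat((category, input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)  # log-probabilities over the 59 letters
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)
```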

The network output at one step looks like this:

ipdb> output
Variable containing:

Columns 0 to 9 
-4.0722 -3.9961 -4.0710 -4.0514 -4.0858 -4.0823 -4.1343 -4.1031 -4.0499 -4.0209

Columns 10 to 19 
-3.9813 -4.0423 -4.0766 -4.0675 -4.0792 -4.1629 -4.1008 -3.9918 -4.0858 -4.0858

Columns 20 to 29 
-4.0786 -4.1626 -3.9722 -4.1445 -4.0758 -4.0858 -3.9976 -4.1414 -4.0328 -4.1664

Columns 30 to 39 
-4.1335 -4.0858 -4.0842 -4.1305 -4.1626 -4.0840 -4.0858 -4.1485 -4.1114 -4.0157

Columns 40 to 49 
-4.1969 -4.1047 -4.0804 -4.2013 -3.9460 -4.1410 -4.1158 -4.1407 -4.0615 -4.0736

Columns 50 to 58 
-4.0753 -4.0336 -3.9570 -4.1130 -4.0925 -4.1859 -3.9428 -3.9744 -4.0120
[torch.FloatTensor of size 1x59]
ipdb> hidden
Variable containing:

Columns 0 to 9
0.0197 0.0799 0.0579 0.0357 -0.0959 -0.0050 -0.0533 -0.1187 0.0021 0.0350

Columns 10 to 19
0.0298 -0.0092 0.0757 -0.0314 -0.0242 0.1162 -0.0419 0.0763 -0.0654 0.0257

Columns 20 to 29
-0.1355 0.0544 -0.0692 -0.0773 0.0837 0.0296 -0.0235 0.0364 -0.0055 0.0206

Columns 30 to 39
0.0798 -0.0957 0.0915 -0.0692 -0.0200 0.0987 -0.0353 0.0346 0.0153 0.1207

Columns 40 to 49
-0.0094 -0.0378 -0.0668 -0.0594 0.0420 0.0532 0.0934 0.0240 -0.0304 -0.0596

Columns 50 to 59
-0.0583 0.0303 0.0805 0.0204 0.0338 0.0935 -0.0100 0.0456 0.0825 0.1461

Columns 60 to 69
-0.0810 -0.0256 0.0232 -0.0264 0.0462 0.0504 -0.1145 0.0272 0.0526 -0.1539

Columns 70 to 79
0.0266 0.0593 -0.0905 -0.0160 0.0949 0.0200 -0.0682 0.0745 -0.0891 0.0204

Columns 80 to 89
-0.0660 -0.0178 -0.0392 0.0662 0.0092 0.0311 -0.1366 -0.0343 0.0106 0.0674

Columns 90 to 99
-0.0050 0.0402 -0.0840 0.0015 0.0300 -0.1098 0.0163 0.0493 0.0760 -0.0958

Columns 100 to 109
-0.0603 0.0409 -0.0110 0.0887 0.0214 -0.0622 0.0281 0.0328 0.0616 -0.0216

Columns 110 to 119
-0.0605 0.1060 0.0050 0.0028 0.0396 0.0168 0.1111 -0.1224 0.0224 -0.0260

Columns 120 to 127
-0.0552 -0.1317 0.0320 -0.0143 0.0014 -0.0062 0.0658 -0.0569
[torch.FloatTensor of size 1x128]

Summary of the training procedure: randomly pick a name together with its nationality, then feed the name's letters, along with the nationality, into the network one at a time. Each step produces a 1x59 output whose entries represent the probabilities of the next letter, and the actual next letter serves as the target when computing the loss.
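Concretely, with NLLLoss each step's loss is just the negative log-probability that the network assigned to the letter that actually came next. A minimal illustration, with a random vector standing in for the network's 1x59 output:

```python
import torch
import torch.nn as nn

criterion = nn.NLLLoss()

# Stand-in for one step's network output: log-probabilities over 59 letters
log_probs = torch.log_softmax(torch.randn(1, 59), dim=1)
target = torch.tensor([7])  # suppose the true next letter is 'h' (index 7)

loss = criterion(log_probs, target)
assert torch.isclose(loss, -log_probs[0, 7])
```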


So how does the network make predictions?

max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    with torch.no_grad():  # no gradient tracking needed when sampling
        category_tensor = categoryTensor(category)
        input = inputTensor(start_letter)
        hidden = rnn.initHidden()

        output_name = start_letter

        for i in range(max_length):
            output, hidden = rnn(category_tensor, input[0], hidden)
            topv, topi = output.topk(1)
            topi = topi[0][0]
            if topi == n_letters - 1:  # EOS: stop generating
                break
            else:
                letter = all_letters[topi]
                output_name += letter
            input = inputTensor(letter)

        return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')
Out:

Rovano
Uakonov
Santonov
Gerter
Eres
Rour
Sala
Para
Aller
Chan
Han
Iun

As you can see, prediction works by feeding in a nationality and a starting letter; the network then emits the most likely next letter at each step, and when that letter is EOS the forward pass stops and the generated name is returned. What the model clearly exploits is the dependency between consecutive letters. Relying only on adjacent-letter dependencies, with no global view of the whole name, would not be enough, so we hope the hidden state carries that global information.

In fact, one could replace the process above with plain counting: given a nationality and the current letter, tabulate the frequencies of the next letter. The problem is that this captures only adjacent-letter statistics and discards the letters' positional information, so names generated this way would certainly have problems.
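That counting baseline is easy to sketch. The snippet below is purely illustrative (`build_bigram_counts` and `sample_greedy` are hypothetical names, not from the tutorial), using `'$'` as an EOS stand-in:

```python
from collections import Counter, defaultdict

def build_bigram_counts(category_lines):
    # counts[(category, letter)] is a Counter over the next letter
    counts = defaultdict(Counter)
    for category, lines in category_lines.items():
        for name in lines:
            padded = name + '$'  # '$' stands in for EOS
            for cur, nxt in zip(padded, padded[1:]):
                counts[(category, cur)][nxt] += 1
    return counts

def sample_greedy(counts, category, start_letter, max_length=20):
    # Always pick the most frequent next letter; stop on EOS or unseen context
    name = start_letter
    while len(name) < max_length:
        nxt = counts[(category, name[-1])].most_common(1)
        if not nxt or nxt[0][0] == '$':
            break
        name += nxt[0][0]
    return name
```

Run on a toy corpus, `sample_greedy(counts, 'Toy', 'a')` just follows the most frequent bigram from each letter, which is exactly the weakness the paragraph above describes: it cannot tell a letter at position 2 from the same letter at position 10.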


Translated from: http://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html#exercises


