深入理解AI-For-Beginners项目中的生成式神经网络
生成式网络概述
在自然语言处理领域,循环神经网络(RNN)及其变体如长短期记忆网络(LSTM)和门控循环单元(GRU)为语言建模提供了强大工具。这些网络能够学习单词顺序,并预测序列中的下一个单词,这使得它们能够用于生成式任务,如普通文本生成、机器翻译甚至图像字幕生成。
在传统的RNN架构中,每个RNN单元都会产生下一个隐藏状态作为输出。然而,我们还可以为每个循环单元添加另一个输出,使我们能够输出一个序列(长度与原始序列相等)。更进一步,我们可以使用不接收每一步输入,仅接受初始状态向量并产生一系列输出的RNN单元。
字符级文本生成
本教程重点介绍简单的生成模型,帮助我们逐字符生成文本。为了简化问题,我们将构建一个字符级网络,它逐个字符地生成文本。在训练过程中,我们需要获取一些文本语料库,并将其拆分为字符序列。
构建字符词汇表
要构建字符级生成网络,我们需要将文本拆分为单个字符而非单词。这可以通过定义不同的分词器来实现:
def char_tokenizer(words):
return list(words)
通过这种方式,我们可以统计训练数据集中所有字符的出现频率,并构建字符到索引的映射关系。例如,字符'a'可能被编码为1,而索引13可能对应字符'c'。
数据编码示例
让我们看一个如何编码数据集中文本的示例:
def enc(x):
return torch.LongTensor(encode(x,voc=vocab,tokenizer=char_tokenizer))
这个函数将输入文本转换为对应的字符索引张量,便于神经网络处理。
训练生成式RNN
我们训练RNN生成文本的方式如下:在每一步,我们取长度为nchars
的字符序列,并要求网络为每个输入字符生成下一个输出字符。
每个训练样本将由nchars
个输入和nchars
个输出(输入序列向左移动一个符号)组成。小批量(minibatch)将由几个这样的序列组成。
小批量生成
我们生成小批量的方法是:取每个长度为l
的新闻文本,从中生成所有可能的输入-输出组合(将有l-nchars
个这样的组合)。它们将形成一个minibatch,每个训练步骤的minibatch大小会有所不同。
nchars = 100
def get_batch(s,nchars=nchars):
ins = torch.zeros(len(s)-nchars,nchars,dtype=torch.long,device=device)
outs = torch.zeros(len(s)-nchars,nchars,dtype=torch.long,device=device)
for i in range(len(s)-nchars):
ins[i] = enc(s[i:i+nchars])
outs[i] = enc(s[i+1:i+nchars+1])
return ins,outs
定义生成器网络
现在让我们定义生成器网络。它可以基于我们在上一单元讨论的任何循环单元(简单RNN、LSTM或GRU)。在我们的示例中,我们将使用LSTM。
由于网络将字符作为输入,且词汇表相当小,我们不需要嵌入层,可以直接将one-hot编码的输入传递给LSTM单元。输出编码器将是一个线性层,将隐藏状态转换为one-hot编码的输出。
class LSTMGenerator(torch.nn.Module):
def __init__(self, vocab_size, hidden_dim):
super().__init__()
self.rnn = torch.nn.LSTM(vocab_size,hidden_dim,batch_first=True)
self.fc = torch.nn.Linear(hidden_dim, vocab_size)
def forward(self, x, s=None):
x = torch.nn.functional.one_hot(x,vocab_size).to(torch.float32)
x,s = self.rnn(x,s)
return self.fc(x),s
文本生成函数
在训练过程中,我们希望能够采样生成的文本。为此,我们将定义generate
函数,该函数将从初始字符串start
开始生成长度为size
的输出字符串。
它的工作方式如下:首先,我们将整个起始字符串通过网络传递,获取输出状态s
和下一个预测字符out
。由于out
是one-hot编码的,我们取argmax
获取词汇表中字符nc
的索引,并使用itos
找出实际字符并将其附加到结果字符列表chars
中。重复此过程size
次以生成所需数量的字符。
def generate(net,size=100,start='today '):
chars = list(start)
out, s = net(enc(chars).view(1,-1).to(device))
for i in range(size):
nc = torch.argmax(out[0][-1])
chars.append(vocab.get_itos()[nc])
out, s = net(nc.view(1,-1),s)
return ''.join(chars)
训练过程
训练循环与我们之前的示例几乎相同,但每1000个epoch我们会打印一次生成的文本样本。
特别需要注意的是我们计算损失的方式。我们需要根据one-hot编码的输出out
和预期文本text_out
(字符索引列表)计算损失。幸运的是,cross_entropy
函数期望第一个参数是未归一化的网络输出,第二个参数是类别编号,这正是我们拥有的。它还会自动对minibatch大小进行平均。
我们还通过samples_to_train
限制训练样本数量,以避免等待时间过长。鼓励您尝试更长时间的训练,可能进行多个epoch(在这种情况下,您需要在此代码周围创建另一个循环)。
改进方向
这个示例已经生成了一些相当不错的文本,但可以通过以下几种方式进一步改进:
-
更好的minibatch生成:我们准备训练数据的方式是从一个样本生成一个minibatch。这并不理想,因为minibatch的大小都不同,有些甚至无法生成,因为文本小于
nchars
。此外,小的minibatch不能充分加载GPU。更明智的做法是从所有样本中获取一大块文本,然后生成所有输入-输出对,打乱它们,并生成大小相等的minibatch。 -
多层LSTM:尝试2或3层LSTM单元是有意义的。正如我们在上一单元中提到的,每一层LSTM从文本中提取特定模式,在字符级生成器的情况下,我们可以期望较低的LSTM级别负责提取音节,而较高级别负责单词和单词组合。这可以通过将层数参数传递给LSTM构造函数来实现。
-
实验不同单元和隐藏层大小:您可能还想尝试GRU单元,看看哪种表现更好,以及不同的隐藏层大小。过大的隐藏层可能导致过拟合(例如网络将学习确切的文本),而过小的尺寸可能不会产生好的结果。
软文本生成与温度参数
在之前的generate
定义中,我们总是将概率最高的字符作为生成文本中的下一个字符。这导致文本经常在相同的字符序列之间"循环",如下例所示:
today of the second the company and a second the company ...
然而,如果我们查看下一个字符的概率分布,几个最高概率之间的差异可能不大,例如一个字符可能有0.2的概率,另一个有0.19的概率。例如,当在序列'play'中寻找下一个字符时,下一个字符同样可能是空格或e(如单词player中)。
这使我们得出结论,选择概率较高的字符并不总是"公平"的,因为选择第二高的可能仍然会导致有意义的文本。更明智的做法是从网络输出给出的概率分布中采样字符。
这种采样可以使用实现所谓多项分布的multinomial
函数来完成。实现这种软文本生成的函数定义如下:
def generate_soft(net, size=100, start='today ', temperature=1.0):
chars = list(start)
out, s = net(enc(chars).view(1,-1).to(device))
for i in range(size):
probs = torch.nn.functional.softmax(out[0][-1]/temperature, dim=0)
nc = torch.multinomial(probs,1).item()
chars.append(vocab.get_itos()[nc])
out, s = net(torch.tensor([[nc]],device=device),s)
return ''.join(chars)
通过调整温度参数,我们可以控制生成文本的随机性程度:
- 低温(如0.3):网络更倾向于选择概率最高的字符,生成文本更加保守和可预测
- 高温(如1.8):网络更倾向于探索各种可能性,生成文本更加随机和创造性
总结
本教程详细介绍了如何使用PyTorch实现字符级文本生成模型。我们从基础概念出发,逐步构建了完整的生成式神经网络,并探讨了多种改进方法。通过调整网络架构、训练策略和生成参数,我们可以获得不同风格的生成文本。
生成式神经网络在自然语言处理领域有着广泛的应用,从简单的文本生成到复杂的对话系统和机器翻译。理解这些基础模型的工作原理对于进一步探索更先进的生成技术(如Transformer和GPT模型)至关重要。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考