第四章:大模型(LLM)
第二部分:神经网络中的 NLP
第二节:Seq2Seq 原理及代码解析
1. Seq2Seq(Sequence-to-Sequence)模型原理
Seq2Seq 是一种处理序列到序列任务(如机器翻译、文本摘要、对话生成等)的深度学习架构,最早由 Google 在 2014 年提出。其核心思想是使用 编码器(Encoder) 将输入序列编码为上下文向量,再通过 解码器(Decoder) 逐步生成输出序列。
1.1 架构组成
-
编码器(Encoder)
-
通常是 RNN、LSTM 或 GRU。
-
输入:序列
。
-
输出:隐藏状态
,作为上下文向量。
-
-
解码器(Decoder)
-
结构类似于编码器。
-
输入:编码器输出的上下文向量 + 上一步预测的输出。
-
输出:目标序列
。
-
-
上下文向量(Context Vector)
-
编码器最后一个隐藏状态
作为整个输入序列的信息摘要。
-
2. 数学公式
-
编码器:
-
解码器:
其中 c 是上下文向量。
3. 经典 Seq2Seq 训练流程
-
输入序列通过编码器,生成上下文向量。
-
解码器利用上下文向量和前一时刻的预测结果,逐步生成输出。
-
使用 教师强制(Teacher Forcing) 技术,训练时将真实标签输入解码器。
4. 改进:Attention 机制
Seq2Seq 传统模型存在 长序列信息丢失 问题。
Attention 通过在每一步解码时为输入序列不同部分分配权重,解决了这个问题。
公式:
其中 是注意力权重。
5. PyTorch 代码解析:Seq2Seq 示例
import torch
import torch.nn as nn
import torch.optim as optim
# Encoder
class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dim, num_layers=1):
super(Encoder, self).__init__()
self.rnn = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)
def forward(self, x):
outputs, hidden = self.rnn(x)
return hidden
# Decoder
class Decoder(nn.Module):
def __init__(self, output_dim, hidden_dim, num_layers=1):
super(Decoder, self).__init__()
self.rnn = nn.GRU(output_dim, hidden_dim, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x, hidden):
output, hidden = self.rnn(x, hidden)
pred = self.fc(output)
return pred, hidden
# Seq2Seq
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder):
super(Seq2Seq, self).__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, src, trg):
hidden = self.encoder(src)
outputs, _ = self.decoder(trg, hidden)
return outputs
# Example usage
input_dim, output_dim, hidden_dim = 10, 10, 32
encoder = Encoder(input_dim, hidden_dim)
decoder = Decoder(output_dim, hidden_dim)
model = Seq2Seq(encoder, decoder)
src = torch.randn(16, 20, input_dim) # batch=16, seq_len=20
trg = torch.randn(16, 20, output_dim)
output = model(src, trg)
print(output.shape) # [16, 20, 10]
6. 应用场景
-
机器翻译(Google Translate)
-
文本摘要(新闻摘要生成)
-
对话系统(聊天机器人)
-
语音识别(语音到文本)