告别调参烦恼:PyTorch Transformer实战指南从BERT到GPT

告别调参烦恼:PyTorch Transformer实战指南从BERT到GPT

【免费下载链接】pytorch Python 中的张量和动态神经网络,具有强大的 GPU 加速能力 【免费下载链接】pytorch 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch

你是否在训练Transformer模型时遇到过这些问题:参数调来调去效果还是不好?BERT和GPT看起来相似却不知道如何选择?本文将带你从基础开始,一步步掌握PyTorch中Transformer的实现,轻松上手BERT和GPT模型。读完本文,你将能够:

  • 理解Transformer的核心结构和工作原理
  • 使用PyTorch快速搭建BERT和GPT模型
  • 掌握模型调优的关键技巧
  • 解决实际应用中常见的问题

Transformer基础架构

Transformer是一种基于自注意力机制(Self-Attention Mechanism)的神经网络模型,由Google团队在2017年提出。与传统的RNN和CNN相比,Transformer能够更好地捕捉长距离依赖关系,同时具有高度的并行性,极大地提高了训练效率。

Transformer核心组件

PyTorch的torch.nn模块提供了完整的Transformer实现,主要包括以下组件:

# Transformer核心组件
from torch.nn import Transformer, TransformerEncoder, TransformerDecoder
from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer
  • TransformerEncoderLayer:编码器的基本单元,包含多头自注意力和前馈神经网络
  • TransformerDecoderLayer:解码器的基本单元,除了编码器包含的组件外,还有编码器-解码器注意力层
  • TransformerEncoder:由多个EncoderLayer堆叠而成
  • TransformerDecoder:由多个DecoderLayer堆叠而成
  • Transformer:完整的Transformer模型,包含编码器和解码器

Transformer工作原理

Transformer的工作流程可以分为以下几个步骤:

  1. 输入序列通过嵌入层(Embedding)转换为向量表示
  2. 添加位置编码(Positional Encoding)以保留序列顺序信息
  3. 编码器处理输入序列,生成上下文向量
  4. 解码器利用编码器的输出和自身输入生成目标序列

下面是PyTorch中Transformer的基本使用示例:

# Transformer基本使用示例
import torch
import torch.nn as nn

# 定义模型参数
d_model = 512  # 模型维度
nhead = 8      # 注意力头数
num_layers = 6 # 编码器/解码器层数

# 创建Transformer模型
transformer = nn.Transformer(d_model, nhead, num_layers)

# 随机生成输入数据
src = torch.rand(10, 32, d_model)  # 源序列: (序列长度, 批次大小, 特征维度)
tgt = torch.rand(20, 32, d_model)  # 目标序列: (序列长度, 批次大小, 特征维度)

# 模型前向传播
output = transformer(src, tgt)
print(output.shape)  # 输出形状: (20, 32, 512)

BERT实现:双向Transformer的应用

BERT(Bidirectional Encoder Representations from Transformers)是Google在2018年提出的预训练语言模型,它利用Transformer的编码器部分,通过双向注意力机制学习文本表示。

BERT模型结构

BERT的核心是多层Transformer编码器,其结构定义在torch/nn/modules/transformer.py中。下面是使用PyTorch实现BERT基础结构的示例:

# BERT基础结构实现
class BERT(nn.Module):
    def __init__(self, vocab_size, d_model=768, nhead=12, num_layers=12):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = nn.Embedding(512, d_model)  # 位置编码
        
        # 创建Transformer编码器层
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead,
            dim_feedforward=3072,  # 前馈网络维度
            batch_first=True       # 批次优先
        )
        
        # 创建Transformer编码器
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        self.cls_head = nn.Linear(d_model, 2)  # 分类头
        
    def forward(self, x, attention_mask=None):
        # 添加位置编码
        seq_len = x.size(1)
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.embedding(x) + self.pos_encoder(pos)
        
        # 通过Transformer编码器
        if attention_mask is not None:
            # 创建注意力掩码
            mask = (attention_mask == 0).unsqueeze(1).repeat(1, seq_len, 1)
            output = self.transformer_encoder(x, src_key_padding_mask=attention_mask)
        else:
            output = self.transformer_encoder(x)
            
        # CLS token用于分类
        cls_output = self.cls_head(output[:, 0, :])
        return cls_output

BERT预训练任务

BERT的预训练包含两个主要任务:

  1. 掩码语言模型(MLM):随机掩盖输入序列中的部分token,让模型预测被掩盖的token
  2. 下一句预测(NSP):判断两个句子是否为连续的上下文

下面是实现MLM任务的示例代码:

# BERT掩码语言模型任务
class MLMHead(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.decoder = nn.Linear(d_model, vocab_size)
        
    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.nn.functional.gelu(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        logits = self.decoder(hidden_states)
        return logits

# 添加MLM头到BERT
class BERTWithMLM(BERT):
    def __init__(self, vocab_size, d_model=768, nhead=12, num_layers=12):
        super().__init__(vocab_size, d_model, nhead, num_layers)
        self.mlm_head = MLMHead(d_model, vocab_size)
        
    def forward(self, x, attention_mask=None, masked_positions=None):
        # 获取Transformer输出
        hidden_states = self.transformer_encoder(x)
        
        # 如果指定了掩码位置,则只预测这些位置
        if masked_positions is not None:
            masked_hidden = hidden_states[torch.arange(hidden_states.size(0)), masked_positions]
            logits = self.mlm_head(masked_hidden)
        else:
            logits = self.mlm_head(hidden_states)
            
        return logits

GPT实现:自回归Transformer的应用

GPT(Generative Pre-trained Transformer)是OpenAI提出的生成式预训练模型,它使用Transformer的解码器部分,通过自回归方式生成文本。

GPT模型结构

与BERT不同,GPT只使用Transformer的解码器部分,并采用了因果掩码(Causal Mask)确保预测时只能看到前面的token。下面是GPT基础结构的实现:

# GPT基础结构实现
class GPT(nn.Module):
    def __init__(self, vocab_size, d_model=768, nhead=12, num_layers=12):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = nn.Embedding(1024, d_model)
        
        # 创建Transformer解码器层
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=3072,
            batch_first=True
        )
        
        # 创建Transformer解码器
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers
        )
        
        self.generator = nn.Linear(d_model, vocab_size)
        
    def forward(self, x):
        # 添加位置编码
        seq_len = x.size(1)
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.embedding(x) + self.pos_encoder(pos)
        
        # 创建因果掩码
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=x.device)
        
        # 通过Transformer解码器
        output = self.transformer_decoder(x, x, tgt_mask=mask)
        logits = self.generator(output)
        return logits

GPT文本生成

GPT的主要应用是文本生成,下面是使用GPT进行文本生成的示例:

# GPT文本生成
def generate_text(model, start_text, tokenizer, max_length=100, temperature=1.0):
    model.eval()
    input_ids = tokenizer.encode(start_text, return_tensors='pt')
    
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            next_token_logits = outputs[0, -1, :] / temperature
            next_token_id = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1)
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
            
            if next_token_id.item() == tokenizer.eos_token_id:
                break
                
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# 使用示例
# generated_text = generate_text(gpt_model, "今天天气真好", tokenizer)
# print(generated_text)

模型调优与实践技巧

参数调优建议

  1. 学习率:Transformer模型通常使用较小的学习率(如1e-4到5e-5)
  2. 批次大小:尽可能使用大批次,如受内存限制可使用梯度累积
  3. 优化器:推荐使用AdamW优化器,权重衰减设为0.01

训练技巧

  1. 预热学习率:使用学习率预热策略可以稳定训练初期的梯度
  2. 梯度裁剪:对梯度进行裁剪可以防止梯度爆炸,通常阈值设为1.0
  3. 混合精度训练:使用PyTorch的AMP(自动混合精度)功能加速训练
# 使用AMP进行混合精度训练
scaler = torch.cuda.amp.GradScaler()

for inputs, labels in dataloader:
    optimizer.zero_grad()
    
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

总结与展望

Transformer模型已经成为自然语言处理领域的基础架构,PyTorch提供了灵活高效的Transformer实现,使得我们可以轻松构建BERT、GPT等先进模型。随着硬件和算法的不断发展,Transformer模型的应用范围也在不断扩大,从NLP到计算机视觉、语音识别等领域都取得了突破性进展。

官方文档:docs/source/modules/nn.rst

希望本文能够帮助你更好地理解和使用PyTorch中的Transformer模块,如果你有任何问题或建议,欢迎在评论区留言讨论!记得点赞、收藏本文,关注作者获取更多AI技术分享。

下一期我们将介绍如何使用PyTorch Lightning加速Transformer模型的训练和部署,敬请期待!

【免费下载链接】pytorch Python 中的张量和动态神经网络,具有强大的 GPU 加速能力 【免费下载链接】pytorch 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值