Training a Transformer Model
Learning Objectives
This course is designed to help learners master the workflow of training a Transformer model.
Related Knowledge Points
- Transformer model training
Learning Content
1. Transformer Model Training
1.1 Defining the Model
- Set the inline display mode for Matplotlib plots
%matplotlib inline
- In this course, we train a TransformerEncoder model on a language modeling task. The goal of language modeling is to assign a probability to how likely a given word (or sequence of words) is to follow a given word sequence.
- A batch of tokens is first passed through an embedding layer, followed by a positional encoding layer that accounts for word order. (The TransformerEncoder is composed of multiple TransformerEncoderLayer layers.)
- In addition to the input sequence, a square attention mask is required, because the self-attention layers in the TransformerEncoder are only allowed to attend to earlier positions in the sequence. For the language modeling task, any tokens at future positions must be masked; a small example of this mask is shown right after the model code below.
- To produce a probability distribution over output words, the output of the TransformerEncoder model is passed through a linear layer followed by a log-softmax. (In the code below the log-softmax is folded into nn.CrossEntropyLoss during training, so the model itself returns raw logits.)
import math
from typing import Tuple
import torch
import torch_npu
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
class TransformerModel(nn.Module):
def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
nlayers: int, dropout: float = 0.5):
super().__init__()
self.model_type = 'Transformer'
self.pos_encoder = PositionalEncoding(d_model, dropout)
encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.encoder = nn.Embedding(ntoken, d_model)
self.d_model = d_model
self.decoder = nn.Linear(d_model, ntoken)
self.init_weights()
def init_weights(self) -> None:
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
"""
Args:
src: Tensor, shape [seq_len, batch_size]
src_mask: Tensor, shape [seq_len, seq_len]
Returns:
output Tensor of shape [seq_len, batch_size, ntoken]
"""
src = self.encoder(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, src_mask)
output = self.decoder(output)
return output
def generate_square_subsequent_mask(sz: int) -> Tensor:
"""Generates an upper-triangular matrix of -inf, with zeros on diag."""
return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
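As a quick illustration (not part of the training flow itself), printing a small mask shows the additive -inf pattern that prevents each position from attending to later positions:

# Illustrative only: a 5x5 causal mask. Row i corresponds to query position i;
# the -inf entries are added to the attention scores, so position i can never
# attend to a position j > i.
print(generate_square_subsequent_mask(5))
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])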
The PositionalEncoding module injects information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the word embeddings so that the two can be summed. Here we use sine and cosine functions of different frequencies.
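Concretely, the code below implements the standard sinusoidal encoding, where pos is the position in the sequence and i indexes pairs of embedding dimensions (div_term in the code computes the 10000^{-2i/d_model} factor):
\begin{align} \mathrm{PE}_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad \mathrm{PE}_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \end{align}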
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: Tensor, shape [seq_len, batch_size, embedding_dim]
"""
x = x + self.pe[:x.size(0)]
return self.dropout(x)
1.2 Loading and Processing the Data
- This course uses the WikiText-2 dataset.
- The vocabulary object vocab is built from the training dataset and is used to convert tokens into numerical tensors. In the WikiText-2 dataset, rare tokens are represented by <unk>.
- For one-dimensional sequence data, the batchify() function arranges the data into batch_size columns. If the data does not divide evenly into batch_size columns, it is trimmed to fit.
- For example, with the alphabet as the data (total length 26) and batch_size = 4, we would divide the alphabet into 4 sequences of length 6:
\begin{align}\begin{bmatrix} \text{A} & \text{B} & \text{C} & \ldots & \text{X} & \text{Y} & \text{Z} \end{bmatrix} \Rightarrow \begin{bmatrix} \begin{bmatrix}\text{A} \\ \text{B} \\ \text{C} \\ \text{D} \\ \text{E} \\ \text{F}\end{bmatrix} & \begin{bmatrix}\text{G} \\ \text{H} \\ \text{I} \\ \text{J} \\ \text{K} \\ \text{L}\end{bmatrix} & \begin{bmatrix}\text{M} \\ \text{N} \\ \text{O} \\ \text{P} \\ \text{Q} \\ \text{R}\end{bmatrix} & \begin{bmatrix}\text{S} \\ \text{T} \\ \text{U} \\ \text{V} \\ \text{W} \\ \text{X}\end{bmatrix} \end{bmatrix}\end{align}
- Batching enables more parallelizable processing. However, it also means the model treats each column independently; in the example above, for instance, the dependency between G and F cannot be learned. (A quick sanity check of batchify() follows the data-processing code below.)
# Download the dataset and extract it
!wget https://model-community-picture.obs.cn-north-4.myhuaweicloud.com/ascend-zone/notebook_datasets/c2a45b82e85e11ef8400fa163edcddae/wikitext-2.tar.gz
!tar -zxvf wikitext-2.tar.gz
import torch
from collections import Counter
import pandas as pd
from torch.utils.data import Dataset, DataLoader
# Manually define a simple tokenizer
def basic_english_tokenizer(text):
return text.strip().split()
# Manually build the vocabulary
def build_vocab(file_path):
df = pd.read_parquet(file_path)
counter = Counter()
for _, row in df.iterrows():
        # 'text' is the name of the text column in the dataset; adjust if needed
text = str(row['text'])
counter.update(basic_english_tokenizer(text))
vocab = {word: idx for idx, (word, _) in enumerate(counter.most_common())}
vocab['<unk>'] = len(vocab)
default_index = vocab['<unk>']
return vocab, default_index
# Read the data and convert it into a tensor
def data_process(file_path, vocab, default_index):
df = pd.read_parquet(file_path)
data = []
for _, row in df.iterrows():
        # 'text' is the name of the text column in the dataset; adjust if needed
text = str(row['text'])
tokens = basic_english_tokenizer(text)
token_ids = [vocab.get(token, default_index) for token in tokens]
if token_ids:
data.extend(token_ids)
return torch.tensor(data, dtype=torch.long)
# Dataset file paths
train_file = 'wikitext-2/data/train-00000-of-00001.parquet'
val_file = 'wikitext-2/data/validation-00000-of-00001.parquet'
test_file = 'wikitext-2/data/test-00000-of-00001.parquet'
# Build the vocabulary
vocab, default_index = build_vocab(train_file)
# Process the data
train_data = data_process(train_file, vocab, default_index)
val_data = data_process(val_file, vocab, default_index)
test_data = data_process(test_file, vocab, default_index)
device = torch.device('npu:0')
def batchify(data: torch.Tensor, bsz: int) -> torch.Tensor:
"""Divides the data into bsz separate sequences, removing extra elements
that wouldn't cleanly fit.
Args:
data: Tensor, shape [N]
bsz: int, batch size
Returns:
Tensor of shape [N // bsz, bsz]
"""
seq_len = data.size(0) // bsz
data = data[:seq_len * bsz]
data = data.view(bsz, seq_len).t().contiguous()
return data.to(device)
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size) # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
print("数据处理完成")
- The get_batch() function generates a pair of input-target sequences for the Transformer model. It subdivides the source data into chunks of length bptt.
- Note that the chunks are taken along dimension 0, which corresponds to the S (sequence) dimension of the Transformer model, while the batch dimension N lies along dimension 1. (A toy example follows the get_batch() code below.)
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
"""
Args:
source: Tensor, shape [full_seq_len, batch_size]
i: int
Returns:
tuple (data, target), where data has shape [seq_len, batch_size] and
target has shape [seq_len * batch_size]
"""
seq_len = min(bptt, len(source) - 1 - i)
data = source[i:i+seq_len]
target = source[i+1:i+1+seq_len].reshape(-1)
return data, target
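A toy example (illustrative only, runs on CPU) of how get_batch() pairs each input token with its next-token target:

# With a [10, 2] source and i = 0, seq_len = min(bptt, 10 - 1 - 0) = 9, so data
# is source[0:9] (shape [9, 2]) and target is source[1:10] flattened (shape [18]):
# every target entry is the "next token" of the corresponding input entry.
toy_source = torch.arange(20).reshape(10, 2)
toy_data, toy_target = get_batch(toy_source, 0)
print(toy_data.shape, toy_target.shape)  # torch.Size([9, 2]) torch.Size([18])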
- The model hyperparameters are defined below.
- The vocabulary size equals the length of the vocab object.
ntokens = len(vocab) # size of vocabulary
emsize = 200 # embedding dimension
d_hid = 200 # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # number of heads in nn.MultiheadAttention
dropout = 0.2 # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)
1.3 Model Training and Evaluation
- We use CrossEntropyLoss with an SGD (stochastic gradient descent) optimizer. The learning rate is initially set to 5.0 and follows a StepLR schedule.
- During training, we use nn.utils.clip_grad_norm_ to prevent gradients from exploding.
import copy
import time
import math
import torch
import torch.nn as nn
from tqdm import tqdm
from torch import Tensor
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
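A minimal sketch (with a separate throwaway parameter and optimizer, so it does not touch the real training state) of how this StepLR schedule decays the learning rate by a factor of 0.95 per scheduler.step() call, i.e. once per epoch in the loop further below:

# Illustrative only: the learning rate goes 5.0 -> 4.75 -> ~4.51 over three steps.
_p = nn.Parameter(torch.zeros(1))
_opt = torch.optim.SGD([_p], lr=5.0)
_sched = torch.optim.lr_scheduler.StepLR(_opt, 1.0, gamma=0.95)
for _ in range(3):
    print(_sched.get_last_lr())  # [5.0], [4.75], [4.5125...]
    _opt.step()
    _sched.step()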
def train(model: nn.Module) -> None:
model.train() # turn on train mode
total_loss = 0.
log_interval = 200
start_time = time.time()
src_mask = generate_square_subsequent_mask(bptt).to(device)
num_batches = len(train_data) // bptt
    # Wrap the batch iterator with tqdm to show a progress bar
with tqdm(range(0, train_data.size(0) - 1, bptt), desc=f'Epoch {epoch} Training', unit='batch') as pbar:
for batch, i in enumerate(pbar):
data, targets = get_batch(train_data, i)
            seq_len = data.size(0)  # actual sequence length of this chunk
            if seq_len != bptt:  # only on last batch
                src_mask = src_mask[:seq_len, :seq_len]
output = model(data, src_mask)
loss = criterion(output.view(-1, ntokens), targets)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
total_loss += loss.item()
if batch % log_interval == 0 and batch > 0:
lr = scheduler.get_last_lr()[0]
ms_per_batch = (time.time() - start_time) * 1000 / log_interval
cur_loss = total_loss / log_interval
ppl = math.exp(cur_loss)
                # Report the running statistics on the tqdm progress bar
pbar.set_postfix({
'loss': f'{cur_loss:.4f}',
'ppl': f'{ppl:.2f}',
'lr': f'{lr:.2f}',
'ms/batch': f'{ms_per_batch:.2f}'
})
total_loss = 0
start_time = time.time()
            # Update the progress bar with the latest batch loss
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
def evaluate(model: nn.Module, eval_data: Tensor) -> float:
model.eval() # turn on evaluation mode
total_loss = 0.
src_mask = generate_square_subsequent_mask(bptt).to(device)
with torch.no_grad():
for i in range(0, eval_data.size(0) - 1, bptt):
data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)  # actual sequence length of this chunk
            if seq_len != bptt:
                src_mask = src_mask[:seq_len, :seq_len]
            output = model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += seq_len * criterion(output_flat, targets).item()
return total_loss / (len(eval_data) - 1)
- Loop over the epochs. If the validation loss is the best we have seen so far, save the model. Adjust the learning rate after each epoch.
best_val_loss = float('inf')
epochs = 1
best_model = None
for epoch in range(1, epochs + 1):
epoch_start_time = time.time()
train(model)
val_loss = evaluate(model, val_data)
val_ppl = math.exp(val_loss)
elapsed = time.time() - epoch_start_time
print('-' * 89)
print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
print('-' * 89)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model = copy.deepcopy(model)
scheduler.step()
out:
-----------------------------------------------------------------------------------------
| end of epoch 1 | time: 49.59s | valid loss 6.53 | valid ppl 685.39
-----------------------------------------------------------------------------------------
- Evaluate the best model on the test dataset.
test_loss = evaluate(best_model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
f'test ppl {test_ppl:8.2f}')
print('=' * 89)
=========================================================================================
| End of training | test loss 6.52 | test ppl 678.52
=========================================================================================