Below, we build on the earlier neural machine translation code, adopting the more powerful multi-head attention mechanism (Multi-Head Attention) and increasing model capacity to implement a stronger translation model. We use the PyTorch framework, take German-to-English translation as the example, and again use the Multi30k dataset from torchtext.
1. Install the required libraries
pip install torch torchtext spacy
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm
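Note: the code below uses torchtext's legacy Field/BucketIterator API, so it assumes an older torchtext release (0.8.x or earlier import paths). On torchtext 0.9–0.11 the same classes live under torchtext.legacy instead:

from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.datasets import Multi30k

From torchtext 0.12 onward the legacy module was removed entirely, so pin an older version if these imports fail.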
2. Code implementation
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
import spacy
import random
import math
import time
# Set random seeds to make results reproducible
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
# Load the German and English spaCy tokenizers
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')
# Define the tokenization functions
def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]
# Define the fields
SRC = Field(tokenize=tokenize_de,
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)
TRG = Field(tokenize=tokenize_en,
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)
# Load the dataset
train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'),
                                                    fields=(SRC, TRG))
# Build the vocabularies (keep only tokens appearing at least twice)
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)
# Select the device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create the iterators
BATCH_SIZE = 128
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device)
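# Sanity check (a minimal sketch, not part of the original script): legacy
# Field batches come out as [seq_len, batch_size], so src and trg below
# should both have 128 as their second dimension.
print(f"SRC vocab size: {len(SRC.vocab)}, TRG vocab size: {len(TRG.vocab)}")
example_batch = next(iter(train_iterator))
print(example_batch.src.shape, example_batch.trg.shape)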
# Define the multi-head attention mechanism
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        # Per-head linear projections
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        # energy[n, h, q, k]: dot product of query q with key k in head h
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        # Scale by sqrt(d_k) and normalize over the key dimension
        attention = torch.softmax(energy / math.sqrt(self.head_dim), dim=3)
        # Weighted sum of values, then concatenate the heads back together
        out = torch.einsum("nhqk,nkhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim)
        return self.fc_out(out)
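# Quick shape check for the layer above (an illustrative sketch; the sizes
# and variable names here are assumptions, not part of the original script):
mha = MultiHeadAttention(embed_size=256, heads=8).to(device)
x = torch.randn(2, 10, 256, device=device)  # [batch, seq_len, embed_size]
print(mha(x, x, x, mask=None).shape)        # -> torch.Size([2, 10, 256])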