Task objective:
Use an RNN to solve a classification task: given a person's name as input, predict which country the name comes from (e.g. mapping the input "satoshi" to the label "japanese").
Dataset:
The raw data is laid out as follows:
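The files themselves are not reproduced here, but judging from the preprocessing command and the task code below, each split is a pair of plain-text files, names/{split}.input and names/{split}.label: the .input file holds one name per line with its characters separated by spaces, and the .label file holds the matching country label on the same line number. Roughly (illustrative lines, not the actual dataset contents):

names/train.input:
s m i t h
k o w a l s k i
y a m a m o t o

names/train.label:
english
polish
japanese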
fairseq preprocessing:
cd into the directory that contains names/ and run the command below; it generates a names-bin folder alongside it, containing the input and output dictionaries plus the input and label files for each split.
fairseq-preprocess \
--trainpref names/train --validpref names/valid --testpref names/test \
--source-lang input --target-lang label \
--destdir names-bin --dataset-impl raw
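Because of --dataset-impl raw, fairseq-preprocess mainly writes the two dictionaries and copies the raw split files under new names. Assuming a recent fairseq version, names-bin should end up containing roughly the following (the file names match what the task code below loads; the log file may differ by version):

names-bin/
  dict.input.txt
  dict.label.txt
  preprocess.log
  train.input-label.input
  train.input-label.label
  valid.input-label.input
  valid.input-label.label
  test.input-label.input
  test.input-label.label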
Registering a custom model:
Create a new file fairseq/models/rnn_classifier.py:
import torch
import torch.nn as nn

# A plain RNN, the same model as in the PyTorch char-RNN tutorial.
class RNN(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

from fairseq.models import BaseFairseqModel, register_model

# Register a custom model under the name 'rnn_classifier'.
@register_model('rnn_classifier')
class FairseqRNNClassifier(BaseFairseqModel):

    @staticmethod
    def add_args(parser):
        parser.add_argument(
            '--hidden-dim', type=int, metavar='N',
            help='dimensionality of the hidden state',
        )

    @classmethod
    def build_model(cls, args, task):
        # fairseq builds the model by calling build_model().
        # Instantiate the plain RNN defined above.
        rnn = RNN(
            input_size=len(task.source_dictionary),
            hidden_size=args.hidden_dim,
            output_size=len(task.target_dictionary),
        )
        # Return the wrapped RNN; which arguments to pass here is
        # determined by __init__ below.
        return FairseqRNNClassifier(
            rnn=rnn,
            input_vocab=task.source_dictionary,
        )

    def __init__(self, rnn, input_vocab):
        super(FairseqRNNClassifier, self).__init__()
        self.rnn = rnn
        self.input_vocab = input_vocab
        # Identity matrix used to look up one-hot encodings of input tokens;
        # register_buffer keeps it on the same device as the model.
        self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

    def forward(self, src_tokens, src_lengths):
        # The arguments of forward() are determined by the 'net_input'
        # of each mini-batch produced by the Task.
        bsz, max_src_len = src_tokens.size()
        hidden = self.rnn.initHidden()
        hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
        hidden = hidden.to(src_tokens.device)  # move to GPU
        # Feed the name one character at a time; the output produced after
        # the last character serves as the classification scores.
        for i in range(max_src_len):
            input = self.one_hot_inputs[src_tokens[:, i].long()]
            output, hidden = self.rnn(input, hidden)
        return output

from fairseq.models import register_model_architecture

# Register a named architecture 'pytorch_tutorial_rnn' that sets the
# model's default hyper-parameters.
@register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
def pytorch_tutorial_rnn(args):
    args.hidden_dim = getattr(args, 'hidden_dim', 128)
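Before training, it can be worth a quick standalone check that the decorators registered everything and that the bare RNN produces the expected shapes. This is a hypothetical sanity-check snippet, not part of the tutorial; the vocabulary size (30) and label count (18) are made-up illustrative numbers:

import torch
from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_REGISTRY
from fairseq.models.rnn_classifier import RNN

# Importing fairseq.models triggers the @register_* decorators above.
assert 'rnn_classifier' in MODEL_REGISTRY
assert 'pytorch_tutorial_rnn' in ARCH_MODEL_REGISTRY

vocab_size, hidden_size, num_labels, bsz = 30, 128, 18, 4  # illustrative sizes
rnn = RNN(input_size=vocab_size, hidden_size=hidden_size, output_size=num_labels)

# One step of the unrolled loop in FairseqRNNClassifier.forward():
hidden = rnn.initHidden().repeat(bsz, 1)
one_hot = torch.eye(vocab_size)[torch.randint(vocab_size, (bsz,))]
output, hidden = rnn(one_hot, hidden)
print(output.shape)  # torch.Size([4, 18]) -- one log-probability per label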
Registering a new Task:
Create a new task file fairseq/tasks/simple_classification.py:
import os
import torch

from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import LegacyFairseqTask, register_task

# Register a task under the name 'simple_classification'.
@register_task('simple_classification')
class SimpleClassificationTask(LegacyFairseqTask):

    @staticmethod
    def add_args(parser):
        parser.add_argument('data', metavar='FILE',
                            help='file prefix for data')
        parser.add_argument('--max-positions', default=1024, type=int,
                            help='max input length')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # Load the dictionaries written by fairseq-preprocess.
        input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
        label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
        print('| [input] dictionary: {} types'.format(len(input_vocab)))
        print('| [label] dictionary: {} types'.format(len(label_vocab)))
        return SimpleClassificationTask(args, input_vocab, label_vocab)

    def __init__(self, args, input_vocab, label_vocab):
        super().__init__(args)
        self.input_vocab = input_vocab
        self.label_vocab = label_vocab

    def load_dataset(self, split, **kwargs):
        prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

        # Read the input names.
        sentences, lengths = [], []
        with open(prefix + '.input', encoding='utf-8') as file:
            for line in file:
                sentence = line.strip()
                # Tokenize the sentence, splitting on spaces.
                tokens = self.input_vocab.encode_line(
                    sentence, add_if_not_exist=False,
                )
                sentences.append(tokens)
                lengths.append(tokens.numel())

        # Read the labels.
        labels = []
        with open(prefix + '.label', encoding='utf-8') as file:
            for line in file:
                label = line.strip()
                labels.append(
                    # Convert label to a numeric ID.
                    torch.LongTensor([self.label_vocab.add_symbol(label)])
                )

        assert len(sentences) == len(labels)
        print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

        self.datasets[split] = LanguagePairDataset(
            src=sentences,
            src_sizes=lengths,
            src_dict=self.input_vocab,
            tgt=labels,
            tgt_sizes=torch.ones(len(labels)),  # targets have length 1
            tgt_dict=self.label_vocab,
            left_pad_source=False,
            # Since our target is a single class label, there's no need for
            # teacher forcing. If we set this to ``True`` then our Model's
            # ``forward()`` method would receive an additional argument called
            # *prev_output_tokens* that would contain a shifted version of the
            # target sequence.
            input_feeding=False,
        )

    def max_positions(self):
        # The source should be at most *args.max_positions* tokens long and
        # the "target" has max length 1.
        return (self.args.max_positions, 1)

    @property
    def source_dictionary(self):
        return self.input_vocab

    @property
    def target_dictionary(self):
        return self.label_vocab
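To see exactly what the model's forward() receives, you can load a split by hand and collate a small batch. A minimal sketch, assuming names-bin has already been built by the preprocessing step above; the argparse.Namespace simply stands in for fairseq's parsed command-line arguments:

import argparse
from fairseq.tasks.simple_classification import SimpleClassificationTask

args = argparse.Namespace(data='names-bin', max_positions=1024)
task = SimpleClassificationTask.setup_task(args)
task.load_dataset('valid')

dataset = task.datasets['valid']
batch = dataset.collater([dataset[i] for i in range(3)])
# These keys become the arguments of FairseqRNNClassifier.forward():
print(batch['net_input'].keys())               # src_tokens, src_lengths
print(batch['net_input']['src_tokens'].shape)  # (3, longest name in batch)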
With all of the above in place, you can launch fairseq-train:
fairseq-train names-bin \
--task simple_classification \
--arch pytorch_tutorial_rnn \
--optimizer adam --lr 0.001 --lr-shrink 0.5 \
--max-tokens 1000 --save-dir checkpoints/rnn_classification
Once training is running, model checkpoint files appear under the checkpoints/rnn_classification folder. Writing an evaluate script then lets you validate the model, as sketched below.
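fairseq has no built-in evaluation command for this custom task, so the evaluate step is a small script of our own. The sketch below follows the interactive eval_classifier.py from fairseq's official name-classification tutorial, which this walkthrough is based on:

from fairseq import checkpoint_utils, data, options, tasks

# Parse command-line arguments for generation, defaulting to our task.
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)

# Set up the task and load the trained model checkpoint.
task = tasks.setup_task(args)
print('| loading model from {}'.format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]

while True:
    sentence = input('\nInput: ')

    # Tokenize into characters, matching the training data format.
    chars = ' '.join(list(sentence.strip()))
    tokens = task.source_dictionary.encode_line(chars, add_if_not_exist=False)

    # Build a mini-batch to feed to the model (bsz = 1).
    batch = data.language_pair_dataset.collate(
        samples=[{'id': -1, 'source': tokens}],
        pad_idx=task.source_dictionary.pad(),
        eos_idx=task.source_dictionary.eos(),
        left_pad_source=False,
        input_feeding=False,
    )

    # Feed the batch to the model and get predictions.
    preds = model(**batch['net_input'])

    # Print the top-3 labels and their log-probabilities.
    top_scores, top_labels = preds[0].topk(k=3)
    for score, label_idx in zip(top_scores, top_labels):
        label_name = task.target_dictionary.string([label_idx])
        print('({:.2f})\t{}'.format(score, label_name))

Saved as, say, eval_classifier.py, it would be run as:
python eval_classifier.py names-bin --path checkpoints/rnn_classification/checkpoint_best.pt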