Fairseq Official Docs: RNN Classification Example

Task objective:

 Implement a classification task with an RNN: given a person's name as input, predict which country the name comes from.

Dataset:

 The raw data, split into train/valid/test, consists of person names paired with their origins. (The screenshots of the raw files from the original post are omitted here.)
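As a rough sketch of the expected layout, based on the upstream fairseq tutorial rather than the omitted screenshots: each line of an .input file is one name split into space-separated characters, and the corresponding line of the .label file is its origin.

names/
  train.input    # one name per line, e.g. "j a m e s"
  train.label    # the matching origin per line, e.g. "English"
  valid.input
  valid.label
  test.input
  test.label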

Fairseq preprocessing:

 cd into the directory that contains names/ and run the command below. It creates a names-bin/ folder alongside it, containing the input and label dictionaries plus the input and label files for each split.

fairseq-preprocess \
  --trainpref names/train --validpref names/valid --testpref names/test \
  --source-lang input --target-lang label \
  --destdir names-bin --dataset-impl raw

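Going by the file names that setup_task and load_dataset below read (a sketch, not a verbatim listing), names-bin/ should contain roughly:

names-bin/
  dict.input.txt           # character (input) vocabulary
  dict.label.txt           # label vocabulary
  train.input-label.input
  train.input-label.label
  valid.input-label.input
  valid.input-label.label
  test.input-label.input
  test.input-label.label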

Registering a custom model:

 Create a new file fairseq/models/rnn_classifier.py:

import torch
import torch.nn as nn

# A plain RNN cell (processes one time step at a time)
class RNN(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)



from fairseq.models import BaseFairseqModel, register_model

# Register a custom model under the name 'rnn_classifier'
@register_model('rnn_classifier')
class FairseqRNNClassifier(BaseFairseqModel):

    @staticmethod
    def add_args(parser):

        parser.add_argument(
            '--hidden-dim', type=int, metavar='N',
            help='dimensionality of the hidden state',
        )

    @classmethod
    def build_model(cls, args, task):
        # fairseq calls build_model to construct the model instance.
        # Instantiate the RNN defined above; inputs are one-hot vectors,
        # so input_size is the source vocabulary size.
        rnn = RNN(
            input_size=len(task.source_dictionary),
            hidden_size=args.hidden_dim,
            output_size=len(task.target_dictionary),
        )
        # Return the wrapped model; the arguments must match __init__ below.
        return FairseqRNNClassifier(
            rnn=rnn,
            input_vocab=task.source_dictionary,
        )

    def __init__(self, rnn, input_vocab):
        super(FairseqRNNClassifier, self).__init__()

        self.rnn = rnn
        self.input_vocab = input_vocab
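        # Identity matrix used as a one-hot lookup table; a buffer moves with
        # the model (e.g. to a GPU) but is not a trainable parameter.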
        self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

    def forward(self, src_tokens, src_lengths):
        # The arguments to forward() are determined by the 'net_input' key of
        # each mini-batch produced by the Task; src_tokens has shape
        # (batch, src_len).

        bsz, max_src_len = src_tokens.size()

        hidden = self.rnn.initHidden()
        hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
        hidden = hidden.to(src_tokens.device)  # move to GPU

        for i in range(max_src_len):
            # Look up one-hot vectors for the i-th character of every name
            # in the batch and advance the RNN one step.
            input = self.one_hot_inputs[src_tokens[:, i].long()]
            output, hidden = self.rnn(input, hidden)
            
        return output


from fairseq.models import register_model_architecture

# Register an architecture named 'pytorch_tutorial_rnn' that supplies
# default hyperparameters for the model
@register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
def pytorch_tutorial_rnn(args):
    args.hidden_dim = getattr(args, 'hidden_dim', 128)
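As a quick sanity check of the RNN cell on its own (a standalone sketch; the toy sizes are made up, and it assumes the file above is importable):

import torch
from fairseq.models.rnn_classifier import RNN  # the file created above

# Toy sizes (assumed): a 10-character vocabulary and 5 labels.
rnn = RNN(input_size=10, hidden_size=16, output_size=5)

bsz = 4
hidden = rnn.initHidden().repeat(bsz, 1)  # (4, 16), mirroring forward() above
chars = torch.randint(0, 10, (bsz,))      # one character ID per example
one_hot = torch.eye(10)[chars]            # (4, 10) one-hot inputs

output, hidden = rnn(one_hot, hidden)
print(output.shape)  # torch.Size([4, 5]): log-probabilities over the labels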

Registering a new Task:

 Create a new task file fairseq/tasks/simple_classification.py:

import os
import torch

from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import LegacyFairseqTask, register_task

# Register a task under the name 'simple_classification'
@register_task('simple_classification')
class SimpleClassificationTask(LegacyFairseqTask):

    @staticmethod
    def add_args(parser):
        parser.add_argument('data', metavar='FILE',
                            help='file prefix for data')
        parser.add_argument('--max-positions', default=1024, type=int,
                            help='max input length')

    @classmethod
    def setup_task(cls, args, **kwargs):
        input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
        label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
        print('| [input] dictionary: {} types'.format(len(input_vocab)))
        print('| [label] dictionary: {} types'.format(len(label_vocab)))

        return SimpleClassificationTask(args, input_vocab, label_vocab)

    def __init__(self, args, input_vocab, label_vocab):
        super().__init__(args)
        self.input_vocab = input_vocab
        self.label_vocab = label_vocab

    def load_dataset(self, split, **kwargs):

        prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

        sentences, lengths = [], []
        with open(prefix + '.input', encoding='utf-8') as file:
            for line in file:
                sentence = line.strip()

                # Tokenize the sentence, splitting on spaces
                tokens = self.input_vocab.encode_line(
                    sentence, add_if_not_exist=False,
                )

                sentences.append(tokens)
                lengths.append(tokens.numel())

        labels = []
        with open(prefix + '.label', encoding='utf-8') as file:
            for line in file:
                label = line.strip()
                labels.append(
                    # Convert label to a numeric ID.
                    torch.LongTensor([self.label_vocab.add_symbol(label)])
                )

        assert len(sentences) == len(labels)
        print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

        self.datasets[split] = LanguagePairDataset(
            src=sentences,
            src_sizes=lengths,
            src_dict=self.input_vocab,
            tgt=labels,
            tgt_sizes=torch.ones(len(labels)),  # targets have length 1
            tgt_dict=self.label_vocab,
            left_pad_source=False,
            # Since our target is a single class label, there's no need for
            # teacher forcing. If we set this to ``True`` then our Model's
            # ``forward()`` method would receive an additional argument called
            # *prev_output_tokens* that would contain a shifted version of the
            # target sequence.
            input_feeding=False,
        )

    def max_positions(self):
        # The source should be less than *args.max_positions* and the "target"
        # has max length 1.
        return (self.args.max_positions, 1)

    @property
    def source_dictionary(self):
        return self.input_vocab

    @property
    def target_dictionary(self):
        return self.label_vocab
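To check that the task wires up correctly before training, a minimal sketch along these lines should work (the Namespace fields here are assumptions; in normal use fairseq builds args from the command line):

from argparse import Namespace
from fairseq.tasks.simple_classification import SimpleClassificationTask

# Point the task at the preprocessed data from the step above.
args = Namespace(data='names-bin', max_positions=1024)
task = SimpleClassificationTask.setup_task(args)
task.load_dataset('valid')

sample = task.datasets['valid'][0]
print(sample)  # dict with 'id', 'source' (char IDs) and 'target' (label ID)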

With the above in place, you can run fairseq-train:

fairseq-train names-bin \
  --task simple_classification \
  --arch pytorch_tutorial_rnn \
  --optimizer adam --lr 0.001 --lr-shrink 0.5 \
  --max-tokens 1000 --save-dir checkpoints/rnn_classification

Once training runs, model checkpoint files are written to the checkpoints/rnn_classification folder. Write an evaluation script and you can validate the model.
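A sketch of such an evaluation script, adapted from the upstream fairseq tutorial (save it as e.g. eval_classifier.py; treat the exact APIs as version-dependent):

from fairseq import checkpoint_utils, data, options, tasks

# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)

# Setup task
task = tasks.setup_task(args)

# Load the trained model
print('| loading model from {}'.format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]

while True:
    sentence = input('\nInput: ')

    # Tokenize into characters
    chars = ' '.join(list(sentence.strip()))
    tokens = task.source_dictionary.encode_line(chars, add_if_not_exist=False)

    # Build a mini-batch to feed to the model
    batch = data.language_pair_dataset.collate(
        samples=[{'id': -1, 'source': tokens}],  # bsz = 1
        pad_idx=task.source_dictionary.pad(),
        eos_idx=task.source_dictionary.eos(),
        left_pad_source=False,
        input_feeding=False,
    )

    # Feed the batch to the model and get predictions
    preds = model(**batch['net_input'])

    # Print the top-3 predictions and their log-probabilities
    top_scores, top_labels = preds[0].topk(k=3)
    for score, label_idx in zip(top_scores, top_labels):
        label_name = task.target_dictionary.string([label_idx])
        print('({:.2f})\t{}'.format(score, label_name))

Run it with something like: python eval_classifier.py names-bin --path checkpoints/rnn_classification/checkpoint_best.pt, then type a name at the prompt to see the top predicted origins.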
