在fairseq中使用新注册的模型、损失函数等

文章详细介绍了如何在fairseq框架下创建一个名为my_dir的目录,包含SimpleLSTM模型的编码器和解码器类,以及交叉熵损失函数。通过`register_model`和`register_model_architecture`注册模型和架构,使得模型能在fairseq中使用。最后,展示了使用`fairseq-train`命令进行模型训练的示例。

在fairseq的目录中创建一个文件夹my_dir

/fairseq/my_dir/
└── __init__.py
└── models
	└── simple_lstm.py
└── criterions
	└── cross_entropy.py

在simple_lstm.py中已经注册好模型和架构(tutorial_simple_lstm)
在fairseq中注册模型、架构以及criterion等,见fairseq 官方文档

import torch
from torch import nn
from fairseq import utils
from fairseq.models import FairseqEncoder, FairseqDecoder, FairseqEncoderDecoderModel
from fairseq.models import register_model, register_model_architecture

class SimpleLSTMEncoder(FairseqEncoder):
    def __init__(self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1):
        super().__init__(dictionary)
        self.args = args

        self.embedding = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )

    def forward(self, src_tokens, src_lengths):
        # 将padding变到右边
        if self.args.left_pad_source:
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                padding_idx=self.dictionary.pad(),
                left_to_right=True,
            )
        
        x = self.embedding(src_tokens)

        x = self.dropout(x)
        # 将序列打包到PackedSequence对象中以提供给LSTM
        x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.cpu(), batch_first=True)

        outputs,(final_hidden, final_cell) = self.lstm(x)

        return {
   
   
            'final_hidden': final_hidden.squeeze(0)
        }
    
    def reorder_encoder_out(self, encoder_out, new_order):
        '''
        encoder_out是从forward函数中的返回值
        new_order(LongTensor)是想要的顺序
        '''
        final_hidden = encoder_out['final_hidden']

        return {
   
   
            'final_hidden':final_hidden.index_select(0, new_order)
        
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值