基于pytorch的Seq2Seq的实现

本文深入探讨Seq2Seq模型,它由编码器和解码器两部分组成,两者均由循环神经网络构建。编码器接收序列输入并输出隐藏状态,解码器利用该状态生成预测序列。代码示例展示了PyTorch中EncoderRNN和DecoderRNN的实现,以及如何将它们组合成Seq2Seq模型进行序列预测。
  1. 结构
    这是一种经典的Seq2Seq结构,由两部分组成,Encoder 和 Decoder(编码器和解码器), 这两个部分的构造几乎一模一样,都是由循环神经网络组成,层数,隐藏层维度都是一样的。
    它的原理是一个序列输入到编码器之后,编码器输出最后的隐藏状态(h)到解码器, h包含了序列的前后信息,然后通过解码器对h处理,逐步生成相应的预测序列。
    在这里插入图片描述

  2. Encoder
    以下是Encoder的一段基本代码,由最基本的RNN组成:
    需要弄清楚它矩阵的变换。

class EncoderRNN(nn.Module):

    def __init__(self,input_size, hidden_size, n_layers, emb_size, dropout=0.5, bidirections=False):
        """
        :param input_size:源序列的编码总数, 例如英文26个字母,那么这里就是26
        :param hidden_size:隐藏层参数
        :param n_layers: 网络层数
        :param emb_size: embedding维度
        :param dropout: dropout参数
        :param bidierctions: 双向还是单向
        """
        super(EncoderRNN, self).__init__()
        self.input_size = input_size
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.emb_size = emb_size
        self.dropout = dropout
        self.bidirecions = bidirections

        self.Embedding = nn.Embedding(self.input_size, self.emb_size)  # embedding编码
        self.RNN = nn.RNN(self.emb_size, self.hidden_size, self.n_layers, dropout=self.dropout, bidirectional=self.bidirecions)

    def forward(self, src):
        # src [seq_len, batch]

        embedded = self.Embedding(src)

        #embedded [seq_len, batch, emb_size]

        out, h = self.RNN(embedded)
        
        # out [seq_len, batch, direction*hidden_size]
        # h [n_layers*directions, batch, hidden_size]

        return out, h

然后我们任意设定一个序列输入到网络中,pytorch中的循环神经网络喂入的数据形式比较特殊,它的batch_size并不是在第一维,而是在第二维,因此我们的数据也要做相应的调整。

if __name__ == "__main__":
    INPUT_SIZE = 10
    HIDDEN_SIZE = 20
    N_LAYERS = 2
    EMB_SIZE = 10
    model = EncoderRNN(INPUT_SIZE, HIDDEN_SIZE, N_LAYERS, EMB_SIZE)
    x = torch.tensor([[1, 2, 3, 4, 9], [2, 2, 3, 
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值