循环神经网络(二):多种结构的RNN

循环神经网络(二):多种结构的RNN

一对一的RNN:

一个输入对应一个输出。
在这里插入图片描述

import torch
from torch import nn

class One2OneRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = nn.RNNCell(input_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, h=None):
        h = self.cell(x, h)
        y = self.fc_out(h)
        return y, h

if __name__ == '__main__':
    x = torch.randn(1,100)
    model = One2OneRNN(100, 100)
    y,h = model(x)
    print(y.shape)
    print(h.shape)

一对多的RNN:

当输出多个值的时候,例如 一对多 或 多对多 的时候,循环多少次?什么时候该跳出循环?是我们需要考虑的 一般设置两个条件

  1. 设置一个最大循环次数,不得超过这个循环次数
  2. 当输出值满足某个条件的时候,跳出循环
    在这里插入图片描述
import torch
from torch import nn

class One2ManyRNN(nn.Module):
    # max_iter: 循环神经网络,可以循环的最大次数
    # end_token: 结束标志,当模型输出 end_token 时,我们人为结束循环
    def __init__(self, input_size, hidden_size, end_token=0, max_iter=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.RNNCell(input_size, hidden_size)
        self.max_iter = max_iter
        self.fc_out = nn.Linear(hidden_size, hidden_size)
        # 结束标志
        self.end_token = torch.ones(self.hidden_size) * end_token

    # x (input_size)
    def forward(self, x, h=None):
        # 初始化 h
        if h is None:
            h = torch.zeros(self.hidden_size)

        outputs = []

        # 循环更新隐藏状态并输出内容
        for i in range(self.max_iter):
            h = self.cell(x, h)
            out = self.fc_out(h)
            outputs.append(out)

            # 如果输出为结束标志,则跳出循环
            if torch.allclose(out, self.end_token, atol=1e-3):
                break
        # 堆叠
        outputs = torch.stack(outputs)
        return outputs, h


if __name__ == '__main__':
    import torch

    x = torch.rand(10)
    model = One2ManyRNN(10, 20)
    y, h = model(x)
    print(y.shape)
    print(h.shape)

多对一的RNN:

多个输入对应一个输出。
在这里插入图片描述

import torch
from torch import nn

# 多对一
class Many2One(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.RNNCell(input_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, hidden_size)

    # x (N, L, input_size)
    def forward(self, x, h=None):
        N, L, input_size = x.shape
        # 初始化 h
        if h is None:
            h = torch.zeros(N, self.hidden_size)
        # 循环对输入进行编码,但不输出
        for i in range(L):
            h = self.cell(x[:, i], h)
        # 循环编码完输入数据后,再输出
        out = self.fc_out(h)
        return out, h


if __name__ == '__main__':
    model = Many2One(10, 20)
    x = torch.randn(5, 3, 10)
    y, h = model(x)
    print(y.shape)
    print(h.shape)

多对多的RNN:

多个输入对应多个输出。
在这里插入图片描述

import torch
from torch import nn


class Many2ManyRNN(nn.Module):
    # max_iter: 输出的最大循环次数
    def __init__(self, input_size, hidden_size, end_token=0, max_iter=10):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.max_iter = max_iter
        self.cell = nn.RNNCell(input_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, hidden_size)
        self.end_token = torch.ones(hidden_size) * end_token

    # x (L, input_size)
    def forward(self, x, h=None):
        L, input_size = x.shape
        if h is None:
            h = torch.zeros(self.hidden_size)

        # 循环编码,编码输入序列,在此过程中不输出内容
        for i in range(L):
            # 更新隐藏状态
            h = self.cell(x[i], h)

        outputs = []

        # 循环输出
        for i in range(self.max_iter):
            # 先输出,再调用 cell
            out = self.fc_out(h)
            outputs.append(out)
            # 判断输出结果是否是停止符号
            if torch.allclose(out, self.end_token, atol=1e-3):
                break
            # 更新隐藏状态
            h = self.cell(torch.zeros(input_size), h)

        # 堆叠
        outputs = torch.stack(outputs)

        return outputs, h


if __name__ == '__main__':
    model = Many2ManyRNN(10, 20)
    x = torch.rand(3, 10)
    y, h = model(x)
    print(y.shape)
    print(h.shape)
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序员miki

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值