pytorch 学习 | 使用pytorch动手实现LSTM模块

本文介绍了LSTM在网络中的作用,包括输入、门控机制和输出,并通过PyTorch展示了LSTM模块的实现过程,对比了自定义实现与官方实现的差异,虽然存在细微差别,但结果保持一致。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

LSTM 简介

LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层(Cell)和输出。其结构如下:
在这里插入图片描述
在这里插入图片描述
上述公式不做解释,我们只要大概记得以下几个点就可以了:

  1. 当前时刻LSTM模块的输入有来自当前时刻的输入值,上一时刻的输出值,输入值和隐含层输出值,就是一共有四个输入值,这意味着一个LSTM模块的输入量是原来普通全连接层的四倍左右,计算量多了许多。
  2. 所谓的门就是前一时刻的计算值输入到sigmoid激活函数得到一个概率值,这个概率值决定了当前输入的强弱程度。 这个概率值和当前输入进行矩阵乘法得到经过门控处理后的实际值。
  3. 门控的激活函数都是sigmoid,范围在(0,1),而输出输出单元的激活函数都是tanh,范围在(-1,1)。

Pytorch实现如下:

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import math

class NaiveLSTM(nn.Module):
    """Naive LSTM like nn.LSTM"""
    def __init__(self, input_size: int, hidden_size: int):
        super(NaiveLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # input gate
        self.w_ii = Parameter(Tensor(hidden_size, input_size))
        self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ii = Parameter(Tensor(hidden_size, 1))
        self.b_hi = Parameter(Tensor(hidden_size, 1))

        # forget gate
        self.w_if = Parameter(Tensor(hidden_size, input_size))
        self.w_hf = Parameter(Tensor
评论 29
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值