深度学习LSTM公式原理

一、介绍

普通的RNN信息不能长久传播,因为结尾较远的信息被稀释的比较厉害,所以引入LSTM,LSTM 是一种特殊的循环神经网络(RNN),专门设计用于解决传统RNN在处理长序列数据时的核心问题:长期依赖(Long-Term Dependencies)。

与RNN相比,LSTM在计算隐藏层时,会包含当前时刻的日记信息。 LSTM通过门控机制(Gates)细胞状态(Cell State),显式控制信息的保留与遗忘,从而解决长期依赖问题。

门机制是通过学习对原权重进行更新。

遗忘门:决定细胞状态中哪些信息需要被丢弃(通过Sigmoid函数输出0~1之间的值)。

传入门:要不要把重要的传入,决定当前输入信息中哪些需要更新到细胞状态。

输出门:比如动词要用单数还是复数。决定当前细胞状态中哪些信息需要输出到隐藏状态。

 

公式详解:

(1) 遗忘门(Forget Gate)

ft=σ(Wf⋅[ht−1,xt]+bf)

  • 符号说明
    • σ:sigmoid函数(输出0~1,表示保留比例)
    • Wf​:遗忘门权重矩阵
    • bf​:遗忘门偏置
  • 物理意义:根据当前输入xt​和上一状态ht−1​,决定细胞状态Ct−1​中哪些信息需要遗忘(接近0)或保留(接近1)。
(2) 输入门(Input Gate)

it=σ(Wi⋅[ht−1,xt]+bi)

C~t=tanh⁡(WC⋅[ht−1,xt]+bC)

  • 符号说明
    • it​:输入门的开关(控制新信息流入)
    • C~t​:候选细胞状态(新信息的原始提案)
  • 物理意义
    • 输入门it决定更新多少新信息
    • C~t通过tanh⁡tanh生成候选记忆(范围-1~1)
(3) 更新细胞状态

Ct=ft⊙Ct−1+it⊙C~t

  • 符号说明
    • ⊙:逐元素相乘(Hadamard积)
  • 物理意义
    • 第一部分ft⊙Ct−1:遗忘旧信息
    • 第二部分it⊙C~t:添加新信息
    • 最终得到更新后的细胞状态Ct
(4) 输出门(Output Gate)

ot=σ(Wo⋅[ht−1,xt]+bo)

ht=ot⊙tanh⁡(Ct)

  • 物理意义
    • ot​决定输出多少细胞状态信息
    • tanh⁡(Ct)将细胞状态压缩到-1~1范围
    • 最终隐藏状态ht是过滤后的输出

 输入序列 → [遗忘门] → 丢弃部分旧记忆 →  [输入门] → 添加新候选记忆 → 更新细胞状态 → 
[输出门] → 生成当前隐藏状态

与RNN对比:

更新公式Ct=ft⊙Ct−1+...包含加法操作,梯度通过多个时间步累加而非连乘,缓解指数衰减 

二、代码

相比与之前的RNN代码,只需要把RNN变为LSTM即可:

class LSTM(nn.Module):
    def __init__(self, embedding_dim=16, hidden_dim=64, vocab_size=vocab_size, num_layers=1, bidirectional=False):
        super(LSTM, self).__init__()
        self.embeding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=bidirectional)
        self.layer = nn.Linear(hidden_dim * (2 if bidirectional else 1), hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        # [bs, seq length]
        x = self.embeding(x)
        # [bs, seq length, embedding_dim] -> shape [bs, seq length,hidden_dim]
        seq_output, (hidden, cell) = self.lstm(x)
        # print(f'seq_output.shape{seq_output.shape}')
        # print(f'hidden.shape{hidden.shape}') #最后一个时间步的输出
        # print(f'cell.shape{cell.shape}') #最后一个时间步的cell state
        # print(seq_output[:, -1, :].squeeze()==hidden.squeeze()) #squeeze() 去掉轴的尺寸为1的哪个轴
        # print(seq_output[:, -1, :].squeeze()==cell.squeeze())
        x = seq_output[:, -1, :]
        # 取最后一个时间步的输出 (这也是为什么要设置padding_first=True的原因)
        x = self.layer(x)
        x = self.fc(x)
        return x
    
sample_inputs = torch.randint(0, vocab_size, (2, 128))
    
print("{:=^80}".format(" 一层单向 LSTM "))       
for key, value in LSTM().named_parameters():
    print(f"{key:^40}paramerters num: {np.prod(value.shape)}")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

何仙鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值