【动手学习pytorch笔记】25.长短期记忆LSTM

本文深入介绍了LSTM(长短期记忆网络)的工作原理,包括输入门、遗忘门、输出门和记忆单元的计算,并与GRU(门控循环单元)进行了对比。此外,还提供了LSTM模型的PyTorch简易实现,包括从头开始的实现和使用内置LSTM层的实现。最后,展示了训练过程和性能评估。

LSTM

理论

输入门(决定是否适用隐藏状态) It=σ(XtWxi+Ht−1Whi+bi)I_t = \sigma(X_tW_{xi}+H_{t-1}W_{hi}+b_i)It=σ(XtWxi+Ht1Whi+bi)

遗忘门(将值朝0减少) Ft=σ(XtWxf+Ht−1Whf+bf)F_t = \sigma(X_tW_{xf}+H_{t-1}W_{hf}+b_f)Ft=σ(XtWxf+Ht1Whf+bf)

输出门(决定是否适用隐藏状态) Ot=σ(XtWxo+Ht−1Who+bo)O_t = \sigma(X_tW_{xo}+H_{t-1}W_{ho}+b_o)Ot=σ(XtWxo+Ht1Who+bo)

候选记忆单元(这不就是RNN的H么) Ct~=tanh(XtWxc+Ht−1Whc+bc)\tilde{C_t} =tanh(X_tW_{xc}+ H_{t-1}W_{hc}+b_c)Ct~=tanh(XtWxc+Ht1Whc+bc)

记忆单元 Ct=Ft⋅Ct−1+It⋅Ct~C_t =F_t \cdot C_{t-1} + I_t \cdot \tilde{C_t}Ct=FtCt1+ItCt~

隐藏状态 Ht=Ot⋅tanh(Ct)H_t =O_t \cdot tanh(C_t)Ht=Ottanh(Ct)

LSTM的记忆单元和和GRU的记忆单元相比,上一步Ct−1C_{t-1}Ct1和这一步Ct~\tilde{C_t}Ct~都可以有权重,不像GRU是 Z 和 (1-Z),当然也可以都没有。

隐藏状态tanh(Ct)tanh(C_t)tanh(Ct)是因为CtC_tCt是[-2, 2],要重新变回[-1, 1]

输出门如果是0,意味着当前的HtH_tHt和之前的信息都不要了,下一个时序看到的是和之前完全无关的

在这里插入图片描述

代码

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

turple终于又用了,这里一个C一个H

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 19744.0 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travellerbut becarfally but s i the peosterto timey itwing 

在这里插入图片描述

简易实现

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 132211.4 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

在这里插入图片描述

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值