9.2 长短期记忆网络

部署运行你感兴趣的模型镜像

变量模型存在长期信息保存短期输入缺失问题解决这一问题最早方法之一长短期记忆网络LSTM有许多门控循环单元一样属性有趣的长短期记忆网络设计门控循环单元稍微复杂一些却比门控循环单元出现了20

9.2.1 门控记忆

长短期记忆网络设计灵感来自计算机逻辑门长短期记忆网络引入了记忆或简称单元cell有些文献认为记忆状态一种特殊类型状态具有相同形状设计目的用于记忆附加信息为了控制记忆我们需要许多其中一个用来记忆输出条目我们称为输出门另一个用来决定何时数据读入记忆我们将其称为输入我们还需要一种机制重置记忆元内容, 遗忘门来管理这种设计动机门控循环单元相同能够通过专用机制决定什么时候记忆或者忽略状态输入

1 输入门遗忘门输出

就如门控循环单元一样当前时间输入前一个时间步状态作为数据送入短期记忆网络如图9-4所示3带有sigmoid激活函数连接处理计算输入门遗忘门输出门3都在(0,1)范围内

状态Ht-1

输入Xt 遗忘门Ft输入门It, 输出门Ot

9-4 长短期记忆网络模型输入门遗忘门输出门

我们来细化一下长短期记忆网络数学表达假设h隐藏单元批量大小n输入d因此输入Xt属于Rnxd前一个时间步状态Ht-1属于Rnxh相应时间步t定义如下输入门It属于Rnxh遗忘门Ft属于Rnxh输出门Ot属于Rnxh他们计算方法如下

It = sigma(XtWxi + Ht-1Whi +bi)

Ft = sigma(XtWxf + Ht-1Whf + bf)

Ot = sigma(XtWxo + Ht-1Who + bo)

2 候选记忆

因为还没有指定各种操作先介绍候选记忆元Ct属于Rnxh 计算上面描述3个门计算类似但是使用tanh函数作为激活函数函数值范围(-1,1)下面导出时间步t公式

Ct = tanh(XtWxc +Ht-1Whc +Bc)

其中Wxc属于RdxhWhc属于Rhxh是权重参数Bc属于Rlxh偏置参数

候选记忆元如图9-5所示

状态Ht-1 输入Xt

遗忘门Ft, 输入门It, 候选记忆元Ct 输出门Ot

带有激活函数连接层

3 记忆元

门控循环单元中有一种机制控制输入遗忘类似长短期记忆网络中有两个用于这样目的输入门It控制采用多少来自Ct新数据而遗忘门Ft控制保留多少过去记忆元Ct-1属于Rnxh内容使用按照元素乘法

Ct = Ft Ct-1 + It Ct

如果遗忘始终l输入始终0过去记忆元Ct-1时间保存并传递当前时间步引入这种设计为了缓解梯度消失问题并更好捕获序列的长距离依赖关系

这样我们就得到了计算记忆元数据如图9-6所示

4 状态

最后我们需要定义如何计算状态Ht属于Rnxh就是输出发挥作用地方长短期记忆网络中仅仅记忆元tanh门控版本这就确保了Ht始终区间(-1,1)

Ht = Ot tanh(Ct)

只要输出接近l我们就能有效将所有记忆信息传递预测部分对于输出门接近0我么只保留记忆所有信息不需要更新状态

9-7提供了数据流图形化演示

9.2.2 从零开始实现

实现长短期记忆网络8.5节中实验相同首先加载时光机器数据集

import torch

from torch import nn

from d2l import torch as d2l

batch_size, num_steps = 32, 35

train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

1初始化模型参数

我们需要定义初始化模型参数如前所述参数num_hiddens定义单元数量我们按照标准0.01高斯分布初始化权重并将偏置设置为0.

def get_lstm_params(vocab_size, num_hiddens, device):

num_inputs = num_outputs = vocab_size

def normal(shape):

return torch.randn(size=shape, device = device) * 0.01

def three():

return (normal((num_inputs, num_hiddens)),

normal((num_hiddens, num_hiddens)),

torch.zeros(num_hiddens, device = device))

W_xi,W_hi,b_i = three()

W_xf, W_hf, b_f = three()

W_xo,W_ho,b_o = three()

W_xc,W_hc,b_c = three()

#输出层参数

W_hq = normal((num_hiddens, num_outputs))

b_q = torch.zeros(num_outputs, device =device)

#附加梯度

params = [W_xi,W_hi,b_i, W_xf, W_hf, b_f, W_xo,W_ho,b_o,W_xc,W_hc,b_c]

for param in params

pararm.requeires_grad_(True)

return params

2 定义模型

初始化函数中长短期记忆网络状态需要返回一个额外记忆元0形状因此我们得到以下状态初始化

def init_lstm_state(batch_size, num_hiddens, device):

return (torch.zeros(batch_szie, num_hiddens), device = device)

torch.zeros(batch_size, num_hiddens), device = device

实际模型定义我们前面讨论意义提供3一个额外记忆元只有状态才会传递输出层记忆元Ct直接参与输出计算

def lstm(inputs, state, params):

[W_xi,W_hi,b_i, W_xf, W_hf, b_f, W_xo,W_ho,b_o,W_xc,W_hc,b_c,W-Hq, b_qj] = params

(H,C) = state

outputs=[]

for X in inputs:

I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)

F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)

O = torch.sigmoid((X @ W_xc) + (H @ W_hc) + b_c)

C = F C + I C_tilda

H = O * torch.tanh(C)

Y = (H @ W_hq) + b_q

outputs.append(Y)

return torch.cat(outputs, dim=0), (H,C)

3 训练预测

我们通过实例化8.5节中引入RNNModelScratch来训练一个长短期记忆网络就如我们9.1中所那样

vocab_szie, num_hiddens, device = len(vocab), 256, d2l.try_gpu()

num_epochs, lr = 500, 1

model = d2l.RNNModeScratch(len(vocab), num_hiddens, device, get_lstm_params,

init_lstm_state, lstm)

d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

3.2.3 简洁实现

我们可以直接实例化长短期记忆网络模型高级API封装前文介绍所有配置细节这段代码运行速度很多使用编译好运算符而不是Python代码来处理之前阐述许多细节

num_inputs = vocab_size

lstm_layer = nn.LSTM(num_inputs, num_hiddens)

model = d2l.RNNModel(lstm_layer, len(vocab))

model = model.to(device)

d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

长短期记忆网络典型具有重要状态控制变量自回归模型多年来已经提供了许多变体多层残差连接不同类型正则化然而由于序列长距离依赖性训练长短期记忆网络其他序列模型成本相当后面我们讲述高级替代模型Transformer

小结

长短期记忆网络3类型输入门遗忘门输出门

长短期记忆网络隐藏输出包括状态记忆只有状态传递输出层

记忆完全属于内部信息

长短记忆网络可以缓解梯度消失梯度爆炸

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值