LSTM理解与pytorch使用

本文深入解析了LSTM的结构原理,阐述其在处理时序数据如视频和句子方面的应用,并通过实例展示了Pytorch中LSTM的使用方法及参数配置,包括输入输出格式、隐藏层信息获取等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

引言

LSTM应该说是每一个做机器学习的人都绕不开的东西,它的结构看起来复杂,但是充分体现着人脑在记忆过程中的特征,下面本文将介绍一下LSTM的结构以及pytorch的用法。

LSTM结构

总体结构

在这里插入图片描述
首先,LSTM主要用来处理带有时序信息的数据,包括视频、句子,它将人脑的对于不同time step的记忆过程理解为一连串的cell分别对不同的时刻输入信息的处理。

详细结构

一个典型的 LSTM 结构可以分别从输入、处理和输出三个角度来解析:

  • 输入: 输入包含三个部分,分别是 cell 的信息𝐶t-1,它代表历史的记忆细胞(cell)状态信息的汇总;隐藏层的信息ht-1, 它是提取到的上个时刻的特征信息; 以及当前的输入𝑥t
  • 处理: 处理部分主要是由遗忘门、输入门、输出门组成。遗忘门由当前的输入和隐藏层信息控制对于历史的 cell 信息的遗忘程度;输入门是决定当前的输入和隐藏
    层信息的利用程度;输出门是由当前的 cell 状态和输入决定输出。
  • 输出: 分别是当前的 cell 状态𝐶’和当前的隐藏层信息h’。

遗忘门
在这里插入图片描述
输入门
在这里插入图片描述
细胞状态更新
在这里插入图片描述
输出门
在这里插入图片描述

Pytorch用法

参数介绍

class torch.nn.LSTM(*args, **kwargs)

参数

  • input_size:输入的特征维度
  • hidden_size:隐藏层的特征维度(即输出的特征维度)
  • num_layers:LSTM隐层的层数,默认为1
  • bias:False则bih=0和bhh=0. 默认为True
  • batch_first:True则输入输出的数据格式为 (batch, seq, feature)
  • dropout:除最后一层,每一层的输出都进行dropout,默认为: 0
  • bidirectional:True则为双向LSTM,默认为False

输入:input, (h0, c0)
输入数据格式
input(seq_len, batch, input_size)
seq_len可以理解为一个视频有多少帧或者一个句子有多少单词,input_size就是一个帧或者一个单词可以用多少维的特征向量表示。
h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)

输出:output, (hn, cn)
输出数据格式
output(seq_len, batch, hidden_size * num_directions)
hn(num_layers * num_directions, batch, hidden_size)
cn(num_layers * num_directions, batch, hidden_size)

使用实例

rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)#(input_size,hidden_size,num_layers)
input = torch.randn(5, 3, 10)#(seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20) #(num_layers,batch,output_size)
c0 = torch.randn(2, 3, 20) #(num_layers,batch,output_size)
output, (hn, cn) = rnn(input, (h0, c0))

output.shape #(seq_len, batch, output_size2)
torch.Size([5, 3, 40])
hn.shape #(num_layers
2, batch, output_size)
torch.Size([2, 3, 20])

获取中间各层的隐藏层信息

lstm = nn.LSTM(3, 3) 
inputs = [torch.randn(1, 3) for _ in range(5)]
# 这里的inputs的大小是一个含有5个1*3的tensor的列表,可以理解为一个5*1*3维的输入,其中5是seq_len,1是batch_size,3是input_size的大小

# 初始化隐藏状态
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))

for i in inputs:
    # 将序列的元素逐个输入到LSTM,经过每步操作,hidden 的值包含了隐藏状态的信息
    out, hidden = lstm(i.view(1, 1, -1), hidden)

关于变长输入

由于在视觉里变长输入的情况较少,这里只给出几个链接:
1.pytorch中如何处理RNN输入变长序列padding

 https://zhuanlan.zhihu.com/p/34418001

2.教你几招搞定 LSTMs 的独门绝技

https://zhuanlan.zhihu.com/p/40391002
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值