torch.nn.LSTM()详解

本文详细介绍了PyTorch中LSTM模块的使用,包括参数解释、输入输出格式以及实例演示。通过理解input_size、hidden_size、num_layers等参数,可以更好地掌握LSTM在网络中的配置。同时,文章还强调了batch_first参数的影响,并展示了如何通过Embedding层预处理输入数据,以及LSTM输出中h_n和c_n的含义。
部署运行你感兴趣的模型镜像

torch.nn.LSTM()详解

输入的参数列表包括:

  • input_size 输入数据的特征维数,通常就是embedding_dim(词向量的维度)
  • hidden_size LSTM中隐层的维度
  • num_layers 循环神经网络的层数
  • bias 用不用偏置,default=True
  • batch_first 这个要注意,通常我们输入的数据shape=(batch_size,seq_length,embedding_dim),而batch_first默认是False,所以我们的输入数据最好送进LSTM之前将batch_size与seq_length这两个维度调换
  • dropout 默认是0,代表不用dropout
  • bidirectional默认是false,代表不用双向LSTM

输入数据包括input,(h_0,c_0):

  • input就是shape=(seq_length,batch_size,input_size)的张量
  • h_0是shape=(num_layers*num_directions,batch_size,hidden_size)的张量,它包含了在当前这个batch_size中每个句子的初始隐藏状态。其中num_layers就是LSTM的层数。如果bidirectional=True,num_directions=2,否则就是1,表示只有一个方向。
  • c_0和h_0的形状相同,它包含的是在当前这个batch_size中的每个句子的初始细胞状态。h_0,c_0如果不提供,那么默认是0

输出数据包括output,(h_n,c_n):

  • output的shape=(seq_length,batch_size,num_directions*hidden_size),
    它包含的是LSTM的最后一时间步的输出特征(h_t),t是batch_size中每个句子的长度。
  • h_n.shape==(num_directions * num_layers,batch,hidden_size)
  • c_n.shape==h_n.shape
  • h_n包含的是句子的最后一个单词(也就是最后一个时间步)的隐藏状态,c_n包含的是句子的最后一个单词的细胞状态,所以它们都与句子的长度seq_length无关
  • output[-1]与h_n是相等的,因为output[-1]包含的正是batch_size个句子中每一个句子的最后一个单词的隐藏状态,注意LSTM中的隐藏状态其实就是输出,cell state细胞状态才是LSTM中一直隐藏的,记录着信息
import torch
batch_size=3
hidden_size=5
embedding_dim=6
seq_length=4
num_layers=1
num_directions=1
vocab_size=20
import numpy as np
input_data=np.random.uniform(0,19,size=(batch_size,seq_length))
input_data=torch.from_numpy(input_data).long()
embedding_layer=torch.nn.Embedding(vocab_size,embedding_dim)
lstm_layer=torch.nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,num_layers=num_layers,
                        bias=True,batch_first=False,dropout=0.5,bidirectional=False)
lstm_input=embedding_layer(input_data)
assert lstm_input.shape==(batch_size,seq_length,embedding_dim)
lstm_input.transpose_(1,0)
assert lstm_input.shape==(seq_length,batch_size,embedding_dim)
output,(h_n,c_n)=lstm_layer(lstm_input)
assert output.shape==(seq_length,batch_size,hidden_size)
assert h_n.shape==c_n.shape==(num_layers*num_directions,batch_size,hidden_size)

ers*num_directions,batch_size,hidden_size)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RRwJIyNY-1654521682679)(C:\Users\sunanpeng\AppData\Roaming\Typora\typora-user-images\image-20220527192414386.png)]

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值