Pythorch中torch.nn.LSTM()参数详解

本文深入解析PyTorch中的LSTM模块,详细介绍其参数意义、输入输出维度要求及使用方法。通过实例演示如何构建LSTM网络,处理序列数据。

通过源代码中可以看到nn.LSTM继承自nn.RNNBase,其初始化函数定义如下

class RNNBase(Module):
    ...
    def __init__(self, mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0., bidirectional=False):

我们需要关注的参数以及其含义解释如下:

    input_size – 输入数据的大小,也就是前面例子中每个单词向量的长度
    hidden_size – 隐藏层的大小(即隐藏层节点数量),输出向量的维度等于隐藏节点数
    num_layers – recurrent layer的数量,默认等于1。
    bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
    batch_first – 默认为False,也就是说官方不推荐我们把batch放在第一维,这个CNN有点不同,
   	 此时输入输出的各个维度含义为 (seq_length,batch,feature)。当然如果你想和CNN一样把batch
    	放在第一维,可将该参数设置为True。
    dropout – 如果非0,就在除了最后一层的其它层都插入Dropout层,默认为0。
    bidirectional – If True, becomes a bidirectional LSTM. Default: False

下面介绍一下输入数据的维度要求(batch_first=False):

输入数据需要按如下形式传入 input, (h_0,c_0)

   	input: 输入数据,即上面例子中的一个句子(或者一个batch的句子),
   		其维度形状为 (seq_len, batch, input_size)
        seq_len: 句子长度,即单词数量,这个是需要固定的。当然假如你的一个句子中只有2个单词,
        	但是要求输入10个单词,这个时候可以用torch.nn.utils.rnn.pack_padded_sequence()	
        	或者torch.nn.utils.rnn.pack_sequence()来对句子进行填充或者截断。
        batch:就是你一次传入的句子的数量
        input_size: 每个单词向量的长度,这个必须和你前面定义的网络结构保持一致
    h_0:维度形状为 (num_layers * num_directions, batch, hidden_size):
        结合下图应该比较好理解第一个参数的含义num_layers * num_directions,
         即LSTM的层数乘以方向数量。这个方向数量是由前面介绍的bidirectional决定,
         如果为False,则等于1;反之等于2。
        batch:同上
        hidden_size: 隐藏层节点数
    c_0: 维度形状为 (num_layers * num_directions, batch, hidden_size),各参数含义和h_0类似。

当然,如果你没有传入(h_0, c_0),那么这两个参数会默认设置为0。

output: 维度和输入数据类似,只不过最后的feature部分会有点不同,
	即 (seq_len, batch, num_directions * hidden_size)
	这个输出tensor包含了LSTM模型最后一层每个time step的输出特征,
	比如说LSTM有两层,那么最后输出的是[h10,h11,...,h1l]  ,
	表示第二层LSTM每个time step对应的输出.另外如果前面你对输入数据
	使用了torch.nn.utils.rnn.PackedSequence,那么输出也会做同样的操作编程packed sequence。
	对于unpacked情况,我们可以对输出做如下处理来对方向作分离
	output.view(seq_len, batch, num_directions, hidden_size), 
	其中前向和后向分别用0和1表示Similarly, the directions can be separated in the packed case.

h_n:(num_layers * num_directions, batch, hidden_size), 只会输出最后一个time step的隐状态结果(如下图所示)。
    Like output, the layers can be separated using h_n.view(num_layers, num_directions, batch, hidden_size) and similarly for c_n.
    
c_n :(num_layers * num_directions, batch, hidden_size),只会输出最后个time step的cell状态结果(如下图所示)。

在这里插入图片描述
代码:

rnn = nn.LSTM(10, 20, 2) # 一个单词向量长度为10,隐藏层节点数为20,LSTM有2层
input = torch.randn(5, 3, 10) # 输入数据由3个句子组成,每个句子由5个单词组成,单词向量长度为10
h0 = torch.randn(2, 3, 20) # 2:LSTM层数*方向 3:batch 20: 隐藏层节点数
c0 = torch.randn(2, 3, 20) # 同上
output, (hn, cn) = rnn(input, (h0, c0))

print(output.shape, hn.shape, cn.shape)

>>> torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])
torch.nn.LSTM是PyTorch中用于实现长短期记忆(LSTM)网络的类。LSTM是一种循环神经网络(RNN)的变种,它在处理序列数据时能够更好地捕捉长期依赖关系。torch.nn.LSTM具有以下参数: - input_size: 输入数据的特征维数,通常是词向量的维度。 - hidden_size: LSTM中隐藏层的维度。 - num_layers: 循环神经网络的层数。 - bias: 是否使用偏置,默认为True。 - batch_first: 输入数据的形状是否为(batch_size, seq_length, embedding_dim),默认为False。 - dropout: 用于控制随机失活的概率,默认为0,表示不使用dropout。 - bidirectional: 是否使用双向LSTM,默认为False。 输入数据包括input、(h_0, c_0),其中: - input: 形状为[seq_length, batch_size, input_size]的张量,包含输入序列的特征。 - h_0: 形状为[num_layers * num_directions, batch_size, hidden_size]的张量,包含每个句子的初始隐藏状态。 - c_0: 形状与h_0相同,包含每个句子的初始细胞状态。 输出数据包括output、(h_t, c_t),其中: - output: 形状为[seq_length, batch_size, num_directions * hidden_size]的张量,包含LSTM最后一层的输出特征。 - h_t: 形状为[num_directions * num_layers, batch_size, hidden_size]的张量,包含每个batch中每个句子的最后一个时间步的隐藏状态。 - c_t: 形状与h_t相同,包含每个batch中每个句子的最后一个时间步的细胞状态。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [torch.nn.LSTM](https://blog.csdn.net/weixin_43269419/article/details/121344564)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [关于torch.nn.LSTM()详解(维度,输入,输出)](https://blog.csdn.net/weixin_44201449/article/details/111129248)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值