LSTM 简介
- 官方文档:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
- h_n:最后一个时间步的输出,即 h_n = output[:, -1, :](一般可以直接输入到后续的全连接层,在 Keras 中通过设置参数 return_sequences=False 获得)
- c_n:最后一个时间步 LSTM cell 的状态(一般用不到)
实例
-
实例:根据红框可以直观看出,h_n 是最后一个时间步的输出,即是 h_n = output[:, -1, :],如何还是无法直观理解,直接看如下截图,对照代码可以非常容易看出它们的关系

-
实例代码:
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(input_size=2, hidden_size=3, batch_first=True)
>>> input = torch.randn(5,4,2)
>>> h0 = torch.randn(1, 5, 3)
>>> c0 = torch.randn(1, 5, 3)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[-0.1046, -0.0316, -0.2261],
[ 0.0702, 0.0756, -0.2856],
[ 0.1146, 0.0666,

最低0.47元/天 解锁文章
1523





