pytorch中已经有很多人实现了convLSTM,但貌似pytorch还没有公布官方版本的convLSTM。以下这一版是比较通用的一个版本,我做注释后放在这里,方便以后查看。
import torch.nn as nn
import torch
class ConvLSTMCell(nn.Module):
def __init__(self, input_dim, hidden_dim, kernel_size, bias):
"""
Initialize ConvLSTM cell.
Parameters
----------
input_dim: int
Number of channels of input tensor.
hidden_dim: int
Number of channels of hidden state.
kernel_size: (int, int)
Size of the convolutional kernel.
bias: bool
Whether or not to add the bias.
"""
super(ConvLSTMCell, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.padding = kernel_size[0] // 2, kernel_size[1] // 2 # 保证在传递过程中 (h,w)不变
self.bias = bias
self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
out_channels=4 * self.hidden_dim, # i门,f门,o门,g门放在一起计算,然后在split开
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias)
def forward(self, input_tensor, cur_state):
h_cur, c_cur = cur_state # 每个timestamp包含两个状态张量:h和c
combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis # 把输入张量与h状态张量沿通道维度串联
combined_conv = self.conv(combined) # i门,f门,o门,g门放在一起计算,然后在split开
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
i = torch.sigmoid(cc_i)
f = torch.sigmoid(cc_f)
o = torch.sigmoid(cc_o)
g = torch.tanh(cc_g)
c_next = f * c_cur + i * g # c状态张量更新
h_next = o * torch.tanh(c_next) # h状态张量更新
return h_next, c_next # 输出当前timestamp的两个状态张量

本文档详细介绍了如何在PyTorch中实现卷积LSTM(ConvLSTM),包括单个ConvLSTMCell的定义及多层ConvLSTM网络的构建。代码示例展示了如何处理输入张量,并提供了初始化隐藏状态的方法。注意,卷积核大小、隐藏通道数和层数需在多层LSTM中保持一致。
最低0.47元/天 解锁文章
8771

被折叠的 条评论
为什么被折叠?



