PyTorch implementation of convolutional LSTM (convLSTM)

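I wrote this as notes when I was first learning neural networks. I've seen some people say they couldn't get the code to run, so use it with judgment. I no longer work on LSTMs and don't remember the details well.
If you see things differently, I'm happy to discuss; if you can't be bothered to discuss it with me, then you're right. Please don't just come in flaming people. Thank you.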

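Convolutional LSTM (convLSTM) was proposed in the paper "Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting" for precipitation nowcasting. The architecture captures both the spatial correlations within each input and the temporal information across inputs, so it is also used for video analysis.
There are already many PyTorch implementations of convLSTM on GitHub; here I picked Convolution_LSTM_pytorch to debug and run.
The file defines two classes, ConvLSTM and ConvLSTMCell, and includes a short piece of calling code.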


ConvLSTM

Contains two functions: __init__ and forward.

__init__: builds a multi-layer convLSTM from the input arguments

    def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=None):
        super(ConvLSTM, self).__init__()
        self.input_channels = [input_channels] + hidden_channels   # layer i reads the output of layer i-1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)
        self.step = step
        # step indices are 0-based; by default record only the last step (avoids a mutable default argument)
        self.effective_step = effective_step if effective_step is not None else [step - 1]
        self._all_layers = []
        for i in range(self.num_layers):        # build a multi-layer convLSTM (one ConvLSTMCell per layer) and keep the cells in _all_layers
            name = 'cell{}'.format(i)
            cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size)
            setattr(self, name, cell)           # setattr registers the cell as a submodule, so its weights are trained
            self._all_layers.append(cell)
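The channel bookkeeping in `__init__` can be checked without PyTorch. The sketch below reuses the values from the usage example later in this post (`input_channels=512`, `hidden_channels=[128, 64, 64, 32, 32]`). It lists each cell's (input, output) channel pair. Because each cell reads the previous cell's hidden state, `input_channels[i]` equals `hidden_channels[i-1]` for every layer after the first.

```python
# Values taken from the usage example of this post
input_channels = 512
hidden_channels = [128, 64, 64, 32, 32]

# Same construction as in ConvLSTM.__init__: prepend the raw input channels
all_inputs = [input_channels] + hidden_channels
layer_shapes = [(all_inputs[i], hidden_channels[i]) for i in range(len(hidden_channels))]

for i, (cin, cout) in enumerate(layer_shapes):
    print(f"cell{i}: {cin} -> {cout} channels")
```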

forward: multi-step forward propagation through the multi-layer convLSTM

    def forward(self, input):
        internal_state = []
        outputs = []
        for step in range(self.step):       # run the forward pass at every time step
            x = input                       # note: the same input tensor is fed at every step
            for i in range(self.num_layers):        # run each ConvLSTMCell of the stack in turn
                # all cells are initialized in the first step
                name = 'cell{}'.format(i)
                if step == 0:       # at the first step, call init_hidden to create this cell's initial state
                    bsize, _, height, width = x.size()
                    (h, c) = getattr(self, name).init_hidden(batch_size=bsize, hidden=self.hidden_channels[i],
                                                             shape=(height, width))
                    internal_state.append((h, c))

                # do forward
                (h, c) = internal_state[i]
                x, new_c = getattr(self, name)(x, h, c)     # ConvLSTMCell.forward; x becomes the next layer's input
                internal_state[i] = (x, new_c)
            # only record effective steps
            if step in self.effective_step:
                outputs.append(x)

        return outputs, (x, new_c)
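`effective_step` is compared against the 0-based loop counter `step`, and that is easy to get wrong. This small pure-Python sketch mirrors only the recording logic of `forward`. The helper name `recorded_steps` is hypothetical. With `step=5, effective_step=[4]`, only the last of the five steps is kept. With `step=1, effective_step=[1]` (the original default arguments), nothing is recorded and `outputs` comes back empty.

```python
def recorded_steps(step, effective_step):
    """Mirror of the bookkeeping in ConvLSTM.forward: which step indices land in `outputs`."""
    recorded = []
    for s in range(step):          # s runs over 0 .. step-1
        if s in effective_step:
            recorded.append(s)
    return recorded

print(recorded_steps(5, [4]))   # [4]: only the last of 5 steps is kept
print(recorded_steps(1, [1]))   # []: step index 1 never occurs when step=1
```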

ConvLSTMCell

Contains three functions: __init__, forward, and init_hidden.

__init__: initializes a single LSTM cell

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()

        assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_features = 4

        self.padding = int((kernel_size - 1) / 2)

        self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)

        self.Wci = None
        self.Wcf = None
        self.Wco = None
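The four input convolutions (Wxi, Wxf, Wxc, Wxo) all read the same `x` with the same kernel size and padding. They can therefore be merged into one `nn.Conv2d` with `4 * hidden_channels` output channels whose weight and bias are their concatenation. The same holds for the four `Wh*` convolutions on `h`.

The sketch below checks this numerically with arbitrary small sizes. This fusion is what the single `self.conv` in the ConvLSTMCell later in this post does. That cell additionally concatenates `x` and `h` along the channel axis.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cin, hidden, k = 3, 4, 3
x = torch.randn(2, cin, 8, 8)

# Four separate input convolutions, like Wxi, Wxf, Wxc, Wxo above
separate = [nn.Conv2d(cin, hidden, k, 1, k // 2, bias=True) for _ in range(4)]

# One fused convolution with 4 * hidden output channels
fused = nn.Conv2d(cin, 4 * hidden, k, 1, k // 2, bias=True)
with torch.no_grad():
    # Conv2d weight shape is (out_channels, in_channels, kH, kW): stack along out_channels
    fused.weight.copy_(torch.cat([conv.weight for conv in separate], dim=0))
    fused.bias.copy_(torch.cat([conv.bias for conv in separate], dim=0))

# Splitting the fused output along channels recovers each gate's pre-activation
fused_out = torch.split(fused(x), hidden, dim=1)
same = all(torch.allclose(a, conv(x), atol=1e-5) for a, conv in zip(fused_out, separate))
print(same)
```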

forward: the forward pass inside one LSTM cell, i.e. the five core equations of convLSTM. The outputs ch and cc are the current hidden_state and the current cell_state, respectively.

    def forward(self, x, h, c):
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)
        cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h))
        co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco)
        ch = co * torch.tanh(cc)
        return ch, cc
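Written out, the five lines of `forward` correspond to the equations below. Here $*$ denotes convolution and $\circ$ the element-wise (Hadamard) product. The biases $b$ come from the `bias=True` input convolutions. In the code, `ci`, `cf`, `cc`, `co`, `ch` are $i_t$, $f_t$, $C_t$, $o_t$, $H_t$. Note that the output gate uses the updated cell state $C_t$, just as `co` uses `cc`.

```latex
\begin{aligned}
i_t &= \sigma\left(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f\right) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh\left(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c\right) \\
o_t &= \sigma\left(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o\right) \\
H_t &= o_t \circ \tanh\left(C_t\right)
\end{aligned}
```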

init_hidden: initializes a convLSTMCell and returns the initial hidden_state and cell_state

    def init_hidden(self, batch_size, hidden, shape):
        device = self.Wxi.weight.device     # create states on the same device as the cell's weights (works on CPU and GPU)
        if self.Wci is None:
            # peephole weights: plain zero tensors, not nn.Parameter, so they are never trained
            self.Wci = torch.zeros(1, hidden, shape[0], shape[1], device=device)
            self.Wcf = torch.zeros(1, hidden, shape[0], shape[1], device=device)
            self.Wco = torch.zeros(1, hidden, shape[0], shape[1], device=device)
        else:
            assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!'
            assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!'
        return (torch.zeros(batch_size, hidden, shape[0], shape[1], device=device),
                torch.zeros(batch_size, hidden, shape[0], shape[1], device=device))
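In `init_hidden`, Wci, Wcf and Wco are zero tensors assigned as plain attributes rather than as `nn.Parameter`s. `nn.Module` does not register plain tensors, so these tensors:

- never appear in `model.parameters()` and are never trained,
- are not saved in `state_dict()`,
- are not moved by `model.to(...)`.

Because they are all zeros, the peephole terms in `forward` always contribute nothing. The hypothetical `PeepholeDemo` module below shows the difference between a plain tensor, an `nn.Parameter`, and a registered buffer.

If you want learnable peepholes, they would have to be `nn.Parameter`s created before the optimizer is built. This is an assumption about how you would adapt the code, not what the original repo does.

```python
import torch
import torch.nn as nn

class PeepholeDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, 3, padding=1)
        # Plain tensor attribute, like Wci/Wcf/Wco in init_hidden: NOT registered
        self.w_plain = torch.zeros(1, 2, 4, 4)
        # nn.Parameter: registered, trained, saved, and moved by .to()
        self.w_param = nn.Parameter(torch.zeros(1, 2, 4, 4))
        # Buffer: saved and moved by .to(), but not trained
        self.register_buffer("w_buffer", torch.zeros(1, 2, 4, 4))

m = PeepholeDemo()
param_names = {name for name, _ in m.named_parameters()}
state_keys = set(m.state_dict().keys())
print(sorted(param_names))      # ['conv.bias', 'conv.weight', 'w_param']
print("w_plain" in state_keys)  # False
```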

Usage

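    # assumes `import torch` and `import torch.nn as nn` at the top of the file, with ConvLSTM / ConvLSTMCell defined above
    if __name__ == '__main__':
        # gradient check

        # build a 5-layer convLSTM
        convlstm = ConvLSTM(input_channels=512, hidden_channels=[128, 64, 64, 32, 32], kernel_size=3, step=5,
                            effective_step=[4])
        loss_fn = torch.nn.MSELoss()

        x = torch.randn(1, 512, 64, 32)
        target = torch.randn(1, 32, 64, 32).double()

        outputs, (h_last, c_last) = convlstm(x)
        output = outputs[0].double()        # output recorded at step 4, shape (1, 32, 64, 32)
        # note: this checks MSELoss's gradient w.r.t. `output`, not the ConvLSTM's own backward pass
        res = torch.autograd.gradcheck(loss_fn, (output, target), eps=1e-6, raise_exception=True)
        print(res)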
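`gradcheck(loss_fn, (output, target))` above perturbs only the arguments of `loss_fn`. It therefore verifies the gradient of `MSELoss` with respect to the network output, not the gradients flowing through the ConvLSTM itself.

The sketch below checks one ConvLSTM step directly. It uses a simplified cell, not the class above: one fused convolution for all four gates and no peepholes. The names `cell_step` and `hidden` are hypothetical. The tensors are kept small so the finite-difference check stays fast, and everything is in double precision, as `gradcheck` expects.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 2
# gradcheck needs double precision; one fused conv produces all four gate pre-activations
conv = nn.Conv2d(1 + hidden, 4 * hidden, kernel_size=3, padding=1).double()

def cell_step(x, h, c):
    """One simplified ConvLSTM step (no peepholes); returns (h_next, c_next)."""
    i, f, o, g = torch.split(conv(torch.cat([x, h], dim=1)), hidden, dim=1)
    c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h_next = torch.sigmoid(o) * torch.tanh(c_next)
    return h_next, c_next

x = torch.randn(1, 1, 5, 5, dtype=torch.double, requires_grad=True)
h0 = torch.randn(1, hidden, 5, 5, dtype=torch.double, requires_grad=True)
c0 = torch.randn(1, hidden, 5, 5, dtype=torch.double, requires_grad=True)

# Compares autograd's gradients w.r.t. x, h0 and c0 against finite differences
ok = torch.autograd.gradcheck(cell_step, (x, h0, c0), eps=1e-6, atol=1e-4)
print(ok)
```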

To use this module from another .py file, just import it.

### How to implement ConvLSTM in PyTorch

ConvLSTM is a neural network architecture that combines convolution with LSTM. It is widely used in video processing, sequence prediction, and related tasks. The following describes how to implement ConvLSTM with PyTorch.

#### 1. Basic principle of ConvLSTM

ConvLSTM replaces the fully connected layers inside a conventional LSTM unit with convolutional layers, so the model can operate directly on spatial data such as images. The core idea is to capture dependencies along the time dimension while preserving the spatial features of the input[^4].

#### 2. Implementation details of ConvLSTM

To build a ConvLSTM in PyTorch, define a new `nn.Module` class that represents a single ConvLSTM unit. Such a unit usually consists of:

- Input gate
- Forget gate
- Output gate
- Cell-state update mechanism

The parameters of all these components are learned through convolution operations.

#### 3. Example code

Below is a simple ConvLSTM implementation:

```python
import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        """
        Initialize ConvLSTM cell.
        Parameters:
            input_dim: Number of channels of input tensor.
            hidden_dim: Number of channels of hidden state.
            kernel_size: Size of the convolutional kernel, as a (height, width) tuple.
            bias: Whether or not to add the bias.
        """
        super(ConvLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        # one convolution computes all four gates from [input, hidden]
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state
        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
```

The code above implements a single ConvLSTM cell. To stack several ConvLSTM layers, extend this module into a complete ConvLSTM network[^5].

#### 4. Data loading and training workflow

For a custom dataset, PyTorch's `Dataset` and `DataLoader` can manage the data flow:

```python
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    def __init__(self, data_path, transform=None):
        ...

    def __len__(self):
        ...

    def __getitem__(self, idx):
        ...


train_dataset = CustomDataset(data_path='path_to_data', transform=transformations)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
```

`train_loader` can then be iterated batch by batch, feeding each batch into the ConvLSTM model for training[^6].
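The stacking and training steps described above can be sketched end to end. This is a minimal sketch under several assumptions:

- `StackedConvLSTM`, `frames` and `targets` are hypothetical names.
- Random tensors wrapped in `TensorDataset` stand in for `CustomDataset`.
- The cell repeats the gate layout of the `ConvLSTMCell` above (with a square `kernel_size`), so the block runs on its own.
- Unlike `ConvLSTM.forward` earlier in this post, which feeds the same input at every step, this sketch feeds a different frame of a `(B, T, C, H, W)` sequence at each step.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class ConvLSTMCell(nn.Module):
    """Same gate layout as the ConvLSTMCell above: one fused conv, no peepholes."""
    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim,
                              kernel_size, padding=kernel_size // 2)

    def forward(self, x, state):
        h, c = state
        i, f, o, g = torch.split(self.conv(torch.cat([x, h], dim=1)), self.hidden_dim, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c

class StackedConvLSTM(nn.Module):
    """Hypothetical multi-layer wrapper: (B, T, C, H, W) -> top layer's last hidden state."""
    def __init__(self, input_dim, hidden_dims, kernel_size=3):
        super().__init__()
        dims = [input_dim] + hidden_dims
        self.cells = nn.ModuleList([ConvLSTMCell(dims[i], dims[i + 1], kernel_size)
                                    for i in range(len(hidden_dims))])

    def forward(self, seq):
        b, t, _, hgt, wid = seq.shape
        states = [(seq.new_zeros((b, cell.hidden_dim, hgt, wid)),
                   seq.new_zeros((b, cell.hidden_dim, hgt, wid))) for cell in self.cells]
        for step in range(t):
            x = seq[:, step]                  # a different frame at every time step
            for i, cell in enumerate(self.cells):
                states[i] = cell(x, states[i])
                x = states[i][0]              # hidden state feeds the next layer
        return x

torch.manual_seed(0)
# Synthetic stand-in for CustomDataset: 8 sequences of 4 frames, each 1 x 16 x 16
frames = torch.randn(8, 4, 1, 16, 16)
targets = torch.randn(8, 8, 16, 16)
loader = DataLoader(TensorDataset(frames, targets), batch_size=2, shuffle=True)

model = StackedConvLSTM(input_dim=1, hidden_dims=[16, 8])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for seq, target in loader:                    # one epoch of training
    optimizer.zero_grad()
    loss = loss_fn(model(seq), target)
    loss.backward()
    optimizer.step()
```

Switching to the `CustomDataset` above would only change where `frames` and `targets` come from. The model and the training loop stay the same.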