This post focuses on understanding a PyTorch implementation of ConvLSTM; for the underlying theory, please refer to other blog posts.
To implement an LSTM, GRU, or any other RNN in PyTorch, you generally need to write your own cell, where each cell encapsulates the computation for a single time step.
In a conventional LSTM, the cell is called t times per layer, where t is the sequence length; with n stacked layers, the cell is therefore called n*t times in total, as the short sketch below illustrates.
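As a minimal illustration of this call pattern (not from the original post), the sketch below unrolls PyTorch's built-in nn.LSTMCell over t time steps; the batch size, sequence length, and feature sizes are arbitrary example values.

import torch
import torch.nn as nn

# Example values (assumptions for illustration only).
batch, t, input_size, hidden_size = 8, 10, 32, 64
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(batch, t, input_size)   # (batch, time, features)
h = torch.zeros(batch, hidden_size)     # initial hidden state
c = torch.zeros(batch, hidden_size)     # initial cell state

outputs = []
for step in range(t):                   # the cell is called t times
    h, c = cell(x[:, step, :], (h, c))
    outputs.append(h)
outputs = torch.stack(outputs, dim=1)   # (batch, t, hidden_size)

A ConvLSTM follows exactly the same per-time-step loop; only the cell's internal computation is replaced by convolutions, as in the class below.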
import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):
    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_size: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super(ConvLSTMCell, self).__init__()
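The snippet above is cut off after the super() call. For completeness, here is a hedged sketch of how such a cell is commonly finished (in the style of the widely used ConvLSTM_pytorch reference implementation); the single Conv2d that produces 4 * hidden_dim channels and is split into the i, f, o, g gates is an assumption based on that reference, not necessarily the exact code the original post continues with.

        # -- assumed continuation of __init__ (sketch) --
        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution computes all four gate pre-activations at once:
        # input gate i, forget gate f, output gate o, candidate g.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        # Concatenate input and hidden state along the channel axis,
        # then split the conv output into the four gates.
        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        # Standard LSTM state update, applied element-wise on feature maps.
        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size):
        # Zero-initialized hidden and cell states, shaped like feature maps.
        return (torch.zeros(batch_size, self.hidden_dim, self.height, self.width),
                torch.zeros(batch_size, self.hidden_dim, self.height, self.width))

With a cell like this, a single ConvLSTM layer simply calls the cell once per time step, exactly as described for the plain LSTM above; the only difference is that the hidden and cell states are now 4-D tensors of shape (batch, hidden_dim, height, width).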