为了提高代码复用,需要以子网的形式生成网络成。paddlepaddle生成子网的方法有两种,见官方文档。但类似LSTM这样的网络,没法直接放入Sequential网络,因为LSTM网络的输出是tuple(官方文档),不能直接传递给下一层网络处理。
这里提供一种解决的思路。
* 新建类,继承LSTM,重新forward函数。
* 在新写的forward函数中调用父类LSTM的forward方法。
* 对父类forward的输出拦截输,进行需要的处理,然后再输出。
例如:
import paddle
#新建自定义类,继承LSTM
class MyLSTM(paddle.nn.LSTM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, inputs):
output,_ = super().forward(inputs) #调用父类的forward
return output[:,-1,:] #去输出的最后序列,传给下一层。因为后面一层不是LSTM层。
class T(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.s = paddle.nn.Sequential(
MyLSTM(5,32),
paddle.nn.BatchNorm(32))
def forward(self, inputs):
return self.s(inputs)
model = T()
paddle.summary(model, (1,20,5))```
---------------------------------------------------------------------------
Layer (type