paddlepaddle的LSTM如何写到Sequential中。

为了提高代码复用,需要以子网的形式生成网络成。paddlepaddle生成子网的方法有两种,见官方文档。但类似LSTM这样的网络,没法直接放入Sequential网络,因为LSTM网络的输出是tuple(官方文档),不能直接传递给下一层网络处理。

这里提供一种解决的思路。

* 新建类,继承LSTM,重新forward函数。
* 在新写的forward函数中调用父类LSTM的forward方法。
* 对父类forward的输出拦截输,进行需要的处理,然后再输出。

例如:

import paddle

#新建自定义类,继承LSTM
class MyLSTM(paddle.nn.LSTM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, inputs):
        output,_ = super().forward(inputs) #调用父类的forward
        return output[:,-1,:] #去输出的最后序列,传给下一层。因为后面一层不是LSTM层。

class T(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.s = paddle.nn.Sequential(
            MyLSTM(5,32), 
            paddle.nn.BatchNorm(32))
    
    def forward(self, inputs):
        return self.s(inputs)

model = T()

paddle.summary(model, (1,20,5))```

---------------------------------------------------------------------------
Layer (type
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值