如何计算lstm网络的复杂度 乘法次数 flops(未完成)

本文介绍了一种计算LSTM网络复杂度的方法:通过计算乘法运算次数来评估。详细解析了LSTM单元内部的乘法操作,并使用PyTorch OpCounter库进行实际的计算演示。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

计算算法复杂度的一种方法是计算前向算法的乘法次数,因此在这篇文章中计算复杂度的方法是计算lstm网络的乘法运算次数

首先要弄清楚lstm的cell中的乘法次数

cell有4个输入,一个输出

有三个门,input gate控制数据输不输进来,不输入就输入0

forget gate控制保存单元memory cell更不更新,不更新就维持原状

output gate控制计算的值输不输出,不输出就输出0

举例

x是input y是output

lstm的cell有4个input,就是矢量x+bias,乘上weight,求和

以信号处理为例,batch为1,全部省略,此时的x大小[1,Nv] 

接下来是用别人写好的库计算,用的是pytorch OpCounter

GitHub - Lyken17/pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model.

按照github上的提示

pip install thop

然后就可以运行下面的例子了

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值