[tensorflow] 如何从pb模型文件中获得参数信息 How to obtain parameters information from a tensorflow .pb file?

为了比较模型复杂度,需要获取模型参数数量。当面对.pb文件而非checkpoint时,发现`tf.trainable_variables()`返回为空。解决方案是手动过滤constant节点,通过检查ndim和name来获取trainable_variables。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

因为要和SOTA比较模型的复杂度,我想知道参数数量。但是模型文件不是tensorflow checkpoint,而是pb文件,我发现当导入graph后,tf.trainable_variables()返回空。
Problem setting : I need to compare with state-of-the-arts the model complexity so the model parameter amount is needed. However the provided model file isn’t the ckpt file, but pb file, and the variables returned by tf.trainable_variables() is found empty.

这个回答给出了方法。
This answer gives the solution.

举例:
In my case:

# import graph
with open('spmc_120_160_4x3f.pb', 'rb') as f:
	graph_def = tf.GraphDef()
	graph_def.ParseFromString(f.read())
	output = tf.import_graph_def(graph_def, input_map=
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值