tensorflow 计算参数量

本文介绍了一种使用TensorFlow计算所有可训练变量尺寸的方法,通过numpy的sum和prod函数结合TensorFlow的trainable_variables获取模型参数总量。
部署运行你感兴趣的模型镜像

参考自https://blog.youkuaiyun.com/feynman233/article/details/79187304,记录下方便之后查看。

print np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])

转载于:https://my.oschina.net/u/4017456/blog/2906986

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

### 如何在 TensorFlow计算深度学习模型的总参数量TensorFlow 中,可以通过内置的方法来获取模型中的总参数量以及区分可训练与非可训练参数的数量。以下是具体的实现方式: #### 使用 `model.summary()` 方法 `model.summary()` 是一种简单有效的方式来查看模型架构及其参数统计信息。该方法会打印每一层的名称、输出形状以及参数数量,并汇总整个模型的总参数量。 ```python import tensorflow as tf # 假设已经定义了一个模型 model model = tf.keras.Sequential([ tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)), tf.keras.layers.Dense(10, activation='softmax') ]) # 打印模型摘要 model.summary() ``` 通过上述代码运行后,可以清晰看到每层的参数数量以及总的 **可训练参数** 和 **非可训练参数** 数量[^1]。 --- #### 访问模型参数属性 如果需要更灵活的方式访问这些数值,则可以直接调用模型对象的相关属性: - `model.trainable_variables`: 返回所有可训练变量列表。 - `model.non_trainable_variables`: 返回所有不可训练变量列表。 下面是一个示例代码展示如何手动计算参数总数: ```python total_params = sum([tf.size(var).numpy() for var in model.variables]) trainable_params = sum([tf.size(var).numpy() for var in model.trainable_variables]) non_trainable_params = total_params - trainable_params print(f"Total Parameters: {total_params}") print(f"Trainable Parameters: {trainable_params}") print(f"Non-trainable Parameters: {non_trainable_params}") ``` 此代码片段利用了 TensorFlow 提供的张量操作函数 `tf.size()` 来逐个累加各变量所占存储空间大小,从而得出精确的结果[^2]。 --- 需要注意的是,在某些特殊情况下(例如残差网络 ResNet 或其他复杂结构),可能存在部分层未被正确初始化或者配置错误的情况,这可能导致报告出来的 shape 不完整等问题。此时应仔细检查模型构建逻辑并修正相应缺陷[^3]。 --- #### 特殊情况处理建议 对于卷积神经网络 (CNN),尤其是像 ResNet 这样的深层架构设计时,由于存在 shortcut connection 需要保证输入输出维度匹配,因此可能引入额外的 \(1 \times 1\) 卷积用于调整 channel size 。这种机制虽然增加了少量参数开销,但却有助于提升整体表现力。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值