pytorch 获取参数量代码片段

这段代码用于计算一个模型的总参数数量和训练参数数量。通过迭代模型的参数,`sum(p.numel() for p in model.parameters())`求和得到全部参数数,而`sum(p.numel() for p in model.parameters() if p.requires_grad)`则计算需要梯度更新的训练参数数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

 

### 计算深度学习模型参数量 为了计算深度学习模型中的参数数量,可以遍历模型的所有可训练参数并累加其数目。下面展示了两种常见的方式:一种适用于PyTorch框架下的模型,另一种则针对TensorFlow/Keras。 #### PyTorch 实现方式 对于基于PyTorch构建的神经网络而言,可以通过访问`model.parameters()`方法获取所有的权重张量对象,进而统计总的参数个数: ```python def count_parameters(model): total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f'Total number of parameters : {total_params}') print(f'Number of trainable parameters : {trainable_params}') # 假设NeuralNetwork类已定义好 input_dim, hidden_dim1, hidden_dim2, output_dim = 784, 500, 250, 10 net = NeuralNetwork(input_dim=input_dim, hidden_dim1=hidden_dim1, hidden_dim2=hidden_dim2, output_dim=output_dim) count_parameters(net)[^3] ``` 此代码片段会输出整个网络结构中所有参数以及仅限于那些参与梯度更新过程(即可训练)部分的具体数值。 #### TensorFlow / Keras 实现方式 当采用TensorFlow或Keras搭建模型时,则可以直接利用内置属性`.count_params()`来快速得到结果: ```python import tensorflow as tf class SimpleModel(tf.keras.Model): def __init__(self): super(SimpleModel, self).__init__() self.dense1 = tf.keras.layers.Dense(64, activation='relu') self.dense2 = tf.keras.layers.Dense(10) def call(self, inputs): x = self.dense1(inputs) return self.dense2(x) simple_model = SimpleModel() print("Total params:", simple_model.count_params())[^1] ``` 这两种方案都能有效地帮助理解所设计架构内部各组件之间的复杂程度及其潜在影响因子大小。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值