# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
这段代码用于计算一个模型的总参数数量和训练参数数量。通过迭代模型的参数,`sum(p.numel() for p in model.parameters())`求和得到全部参数数,而`sum(p.numel() for p in model.parameters() if p.requires_grad)`则计算需要梯度更新的训练参数数。
695

被折叠的 条评论
为什么被折叠?



