In PyTorch, you can count the total number of parameters in a model by iterating over its parameters and calling .numel() on each parameter tensor. A common one-liner:
total_params = sum(p.numel() for p in model.parameters())
print("Total parameters:", total_params)
# Or, equivalently, as an explicit loop wrapped in a function:
def count_parameters(model):
    total = 0  # avoid shadowing the built-in sum()
    for p in model.parameters():
        total += p.numel()
    return total
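The counting approach above can be verified on a toy model. This is a minimal sketch using a hypothetical nn.Sequential network (the layer sizes are illustrative, not from the original text):

```python
import torch.nn as nn

# A tiny illustrative model for demonstrating parameter counting.
model = nn.Sequential(
    nn.Linear(10, 20),  # weight 20*10 + bias 20 = 220 parameters
    nn.ReLU(),          # no parameters
    nn.Linear(20, 1),   # weight 1*20 + bias 1 = 21 parameters
)

total_params = sum(p.numel() for p in model.parameters())
print(total_params)  # 241
```

Each Linear layer contributes weight plus bias elements, so the total is 220 + 21 = 241.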
Code that computes both the parameter count and GFLOPs in one pass, using the thop library:
import torch
from thop import profile
from models.lbunet import LBUNet
tensor = torch.randn(1, 3, 256, 256)
model = LBUNet(num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64])
gt_pre, edge, out = model(tensor)  # optional sanity-check forward pass
flops, params = profile(model, inputs=(tensor, ))
print(f"FLOPs: {flops / 1e9:.4f} GFLOPs")  # four decimal places
print(f"Params: {params / 1e6:.4f} M")     # four decimal places
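When a model contains frozen layers, the total parameter count and the trainable parameter count differ; filtering on p.requires_grad separates the two. A minimal sketch (the model and the frozen bias here are illustrative assumptions, not from the original text):

```python
import torch.nn as nn

# Illustrative model: Linear(10, 5) has 10*5 weights + 5 biases = 55 parameters.
model = nn.Linear(10, 5)
model.bias.requires_grad = False  # freeze the bias to simulate a frozen layer

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 55 50
```

Note that thop's profile() reports all parameters it visits, so for partially frozen models the trainable count above is the one relevant to optimizer state and gradient memory.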