pytorch保存的.pth模型文件

pytorch的.pth模型文件保存
博客链接指向关于pytorch保存的.pth模型文件的内容,涉及pytorch在模型文件保存方面的相关信息技术知识。
你可以使用 PyTorch 加载 `.pth` 模型文件并查看模型的参数量(即模型中可训练参数和总参数的数量)。以下是完整的实现方法: ```python import torch from torch import nn def count_parameters(model): total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) return total_params, trainable_params # 示例:加载 .pth 模型文件 model_path = 'your_model.pth' # 替换为你的模型路径 # 方法1: 如果保存的是整个模型(包括结构) try: model = torch.load(model_path, map_location=torch.device('cpu')) model.eval() total, trainable = count_parameters(model) print(f"模型总参数量: {total:,}") print(f"可训练参数量: {trainable:,}") except Exception as e: print("无法直接加载完整模型,尝试加载 state_dict...") # 方法2: 如果保存的是 state_dict(更常见) try: state_dict = torch.load(model_path, map_location=torch.device('cpu')) # 需要先定义对应的模型结构(例如使用 ResNet、自定义网络等) # 下面以一个简单的自定义模型为例: class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() self.fc = nn.Linear(64, 10) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = x.view(x.size(0), -1) x = self.fc(x) return x model = SimpleModel() model.load_state_dict(state_dict) model.eval() total, trainable = count_parameters(model) print(f"模型总参数量: {total:,}") print(f"可训练参数量: {trainable:,}") except Exception as e: print("加载失败,请检查模型文件格式或提供正确的模型结构。") print(e) ``` ### 解释: - `.pth` 文件可能保存了整个模型对象,也可能只保存了 `state_dict`(推荐方式)。 - 使用 `sum(p.numel() for p in model.parameters())` 可以统计所有参数数量。 - `numel()` 返回张量中元素的总数。 - `p.requires_grad` 判断是否参与梯度更新(即可训练参数)。 > ⚠️ 注意:如果 `.pth` 文件仅包含 `state_dict`,你必须事先知道模型的结构,并实例化相同的模型类才能加载成功。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值