深度学习
文章平均质量分 54
SigMap
这个作者很懒,什么都没留下…
展开
专栏收录文章
- 默认排序
- 最新发布
- 最早发布
- 最多阅读
- 最少阅读
-
pytorch学习007- -预训练中的权重加载(完全导入,部分导入)
问题预训练后的权重如何导入另一个网络模型?预训练对应的网络模型A与未训练的网络模型结构B不对应?2.1 两个网络模型A和B只有部分对应2.2 集合关系上A属于B2.3 集合关系上B属于A方案PyTorch文档torch.nn.modules.module.Module def load_state_dict(self,state_dict: Dict[str, Tensor] | OrderedDict[str, Tensor],strict: bool = …) -> N原创 2022-03-28 16:16:29 · 12912 阅读 · 7 评论 -
pytorch学习006- -回归问题的损失函数
L1Lossfrom torch import nnloss = nn.L1Loss()首先计算对应位置差值然后求绝对值累加除以总的像素数MSELossfrom torch import nnloss = nn.MSELoss()首先计算对应位置差值然后求平方值累加除以总的像素数SmoothL1Lossfrom torch import nnloss = nn.SmoothL1Loss()首先计算对应位置差值如果差值不小1,则求绝对值后-0.5如果差值小原创 2022-03-06 16:19:52 · 735 阅读 · 0 评论 -
pytorch学习005- -torchsummary的使用
torchsummary的使用使用流程安装导入使用官方说明demo建议查看官方demo -->github使用流程安装pip install torchsummary导入from torchsummary import summary使用# 参数说明summary(your_model, input_size=(channels, H, W))myNet = NET() #NET为自己定义的网络模型data = [(3, 100, 100), (3, 100, 100),原创 2022-01-23 11:13:50 · 2074 阅读 · 0 评论 -
pytorch学习004- -libtorch处理多输入/多输出问题
libtorch处理多输入/多输出问题准备工作多输入问题多输出问题参考文章一般处理多输入/多输出问题时,pytorch中容易处理;但是在libtorch中会出现一些问题。这篇博客为记录用。准备工作首先加载多模型。因为可能需要根据不同情况调用多个模型,所以这里预先声明一个modelsstd::vector<torch::jit::script::Module> models;然后使用try/catch来加载模型try { models.push_bac原创 2022-01-16 17:19:26 · 3491 阅读 · 0 评论 -
pytorch学习003- -如何导出c++中可用的pytorch模型
在 C++ 中加载 TorchScript 模型推荐阅读官方文档:如何保存模型如何在c++中加载模型pytorch的c++ apitorchScript文档以下内容基于官方文档写一些注释~~将pytorch模型转换为torch脚本pytorch模型从python到c++是通过torchScript实现的跟踪是指通过示例输入进行一次推理并获取模型的结构和参数,并记录下这些输入在模型中的流转。注释是向模型中添加显式注释,通知编译器对模型代码进行解析。通过跟踪转换为torch脚本并保存o原创 2022-01-13 17:02:59 · 1357 阅读 · 0 评论 -
pytorch学习002- -debug(load_state_dict() missing 1 required positional argument: ‘state_dict‘)
Debugsavept.py文件用来将gpu上训练的模型转换为cpu上推理可用的pt文件以下为部分代码:import torchfrom model import NETprint(torch.__version__) # 1.10.1+cpumodel = NETdevice = torch.device('cpu')state_dict = torch.load(PATH, map_location=device)model.load_state_dict(state_dic原创 2022-01-13 16:18:14 · 3460 阅读 · 3 评论 -
pytorch学习001- -如何保存模型
保存和加载模型只保存模型的参数保存torch.save(model.state_dict(),'xxx.pth')加载model = net() #首先要先定义网络模型state_dict = torch.load('xxx.pth') # 读取pth文件中的参数model.load_state_dict(state_dict['model']) #将参数导入模型这种方法操作比较麻烦,但是比较节省内存。official exampleclass MyModule(torch.nn.M原创 2022-01-13 15:23:27 · 524 阅读 · 0 评论
分享