Pytorch 不训练(frozen)一些神经网络层的方法

Pytorch 不训练(frozen)一些神经网络层的方法

我们在做深度学习的时候经常会使用预训练的模型。很多情况下,加载进来模型是为了完成其他任务,在这种情况下,加载模型的一部分是不需要再训练的。那么我们就需要forozen这些神经网络层。

固定某些层训练,就是将tensor的requires_grad设为False。
此外,一定要记住,我们还需要在optim优化器中再将这些参数过滤掉!
下面见代码:

device = torch.device("cuda" )

    #Try to load models

model = DGCNN(args)
print(str(model))
model = model.to(device)


    
save_model = torch.load('model.t7')
model_dict =  model.state_dict()
### 如何使用已训练神经网络模型进行预测 #### 使用Simulink调用MATLAB中的神经网络模型 为了在Simulink环境中调用由MATLAB训练完成的神经网络模型,需先确保模型已经被妥善保存为`.mat`文件或其他兼容格式。之后可以在Simulink项目里引入这些预训练好的权重参数来构建相应的功能模块[^2]。 ```matlab % MATLAB代码片段用于加载预先训练神经网络模型 load('trained_network.mat'); % 加载包含训练完毕后的网络结构与权值的数据文件 ``` 接着,在Simulink中创建新的仿真环境时,可以通过S-Function Block或者MATLAB Function block将上述导入的模型集成进来,从而实现在实时控制系统内执行诸如回归预测、分类识别等功能[^3]。 #### Python环境下基于PyTorch框架的应用实例 对于采用Python编程语言开发的应用场景而言,特别是当选择了像PyTorch这样的深度学习库来进行建模工作后,则可以直接借助其内置API轻松读取之前存储下来的.pth/.pt格式的模型档案,并据此开展推断操作[^4]。 ```python import torch from torchvision import models, transforms # 假设我们有一个名为'model_best.pth.tar' 的预训练模型存档 checkpoint = torch.load('model_best.pth.tar') pretrained_model = checkpoint['state_dict'] # 创建相同架构的新实例并加载状态字典 model = models.resnet18(pretrained=False) model.load_state_dict(pretrained_model) def predict(image_path): preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) input_image = Image.open(image_path) input_tensor = preprocess(input_image) input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model with torch.no_grad(): output = model(input_batch) probabilities = torch.nn.functional.softmax(output[0], dim=0) top_prob, top_catid = torch.topk(probabilities, 1) return top_prob.item(), top_catid.item() ``` #### C/C++应用程序对接Python/TensorFlow模型案例 如果目标平台是以C/C++为主要开发工具的话,那么还可以考虑通过嵌入式解释器或者其他形式跨语言调用来访问那些原本是在Python下定义和优化过的AI组件。这里给出了一种可能的技术路线图——即利用SWIG等工具封装好必要的接口函数,使得最终产品既能够保持高效运行又能充分利用现有资源[^5]。 ```c // 部分简化版C代码示例展示如何初始化并调用Python端口导出的服务 #include "tensorflowc_py_models.h" int main() { char* image_file = "./test.jpg"; float probability; int class_id; init_python_env(); // 初始化Python环境 load_saved_model("path_to_frozen_graph.pb"); // 加载冻结图或检查点路径下的模型 run_inference_on_image(image_file, &probability, &class_id); // 执行推理获取结果 } ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值