torch.nn.Module.eval

本文详细介绍了PyTorch中torch.nn.Module.eval()函数的作用,即设置模块为评估模式。此模式对某些模块如Dropout和BatchNormalization的行为有显著影响,提供了在训练和评估模式下模块行为差异的概述。

torch.nn.Module.eval:

Sets the module in evaluation mode.

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected,

这个函数仅对特定的模块有效.有些函数在 training/evaluation  模态下有不同的行为,可以到具体模块的文档去查看在不同模态下有什么不同的行为.

备注:

受"训练模态""评估模态"影响的模块有 "dropout层"及"batch normalization层".其它的暂不知道

 

 

 

 

 

### 如何使用 `torch.nn.Module` 加载预训练模型 在 PyTorch 中,加载和保存模型是一个常见的操作。为了实现这一目标,可以利用 `torch.save` 和 `torch.load` 函数以及 `state_dict` 方法来处理模型的状态字典[^2]。 #### 使用 `torch.nn.Module` 加载预训练模型的关键步骤解析 1. **状态字典 (`state_dict`) 的概念** - 每个 `torch.nn.Module` 都有一个内部存储器称为 `state_dict`,它记录了所有的可学习参数及其对应的张量值。这些参数通常是以键值对的形式存在,其中键是参数的名字,而值则是具体的权重矩阵或其他数据结构[^3]。 2. **保存模型的方法** - 可以通过调用 `model.state_dict()` 获取当前模型的所有参数,并将其传递给 `torch.save(path)` 来持久化到磁盘上。例如: ```python torch.save(model.state_dict(), 'model.pth') ``` 3. **加载已保存的模型** - 当需要恢复之前训练好的模型时,先实例化相同的网络架构对象(即重新创建一个具有相同配置的类),再调用其方法 `.load_state_dict(torch.load('path'))` 将先前保存下来的参数赋值回去即可完成整个过程。下面给出具体代码示例: ```python model = Model() # 假设这是你的自定义模型类名 model.load_state_dict(torch.load('model.pth')) model.eval() # 设置为评估模式 ``` 4. **注意事项** - 如果是在不同的设备之间迁移模型,则可能还需要考虑显卡与CPU之间的切换问题,在这种情况下可以在加载的时候指定映射位置,比如这样写法会强制把所有tensor都搬到GPU上去运行如果可用的话: ```python device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.load_state_dict(torch.load('model.pth', map_location=device)) ``` 以上就是基于 `torch.nn.Module` 类型下的标准流程介绍如何正确地存档并提取已经过良好调整后的神经网络权重信息以便后续重复利用或者部署上线等工作场景下应用自如[^4]。 ### 示例代码展示完整的保存与加载逻辑 ```python import torch from torch import nn # 定义简单的卷积神经网络作为例子 class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) return x # 创建模型实例并随机初始化一些输入测试一下前向传播功能正常与否 net = SimpleCNN() example_input = torch.randn((1, 1, 28, 28)) output_before_save = net(example_input).detach().numpy() print("Output before saving:", output_before_save) PATH = './cnn_net.pth' torch.save(net.state_dict(), PATH) loaded_model = SimpleCNN() loaded_model.load_state_dict(torch.load(PATH)) loaded_model.eval() with torch.no_grad(): loaded_output = loaded_model(example_input).numpy() print("Loaded Output after loading from disk:", loaded_output) assert np.allclose(output_before_save, loaded_output), "The outputs do not match!" ``` 上述脚本展示了从构建简单 CNN 到实际执行前后对比验证一致性的全过程演示。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值