pytorch-15-Train_Val_Test&交叉验证

本文详细探讨了使用PyTorch进行深度学习模型的训练、验证和测试流程,包括数据集划分、交叉验证的概念及其在PyTorch中的实现,旨在提升模型的泛化能力。

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms


batch_size=200
learning_rate=0.01
epochs=10

train_db = datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size, shuffle=True)

test_db = datasets.MNIST('./data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.D
d:\ab\mmsegmentation\mmseg\models\backbones\resnet.py:431: UserWarning: DeprecationWarning: pretrained is a deprecated, please use "init_cfg" instead warnings.warn('DeprecationWarning: pretrained is a deprecated, ' d:\ab\mmsegmentation\mmseg\models\losses\cross_entropy_loss.py:251: UserWarning: Default ``avg_non_ignore`` is False, if you would like to ignore the certain label and average loss over non-ignore labels, which is the same with PyTorch official cross_entropy, set ``avg_non_ignore=True``. warnings.warn( 06/11 18:03:55 - mmengine - INFO - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used. d:\ab\mmsegmentation\mmseg\engine\hooks\visualization_hook.py:60: UserWarning: The draw is False, it means that the hook for visualization will not take effect. The results will NOT be visualized or stored. warnings.warn('The draw is False, it means that the ' 06/11 18:03:55 - mmengine - INFO - Hooks will be executed in the following order: before_run: (VERY_HIGH ) RuntimeInfoHook (BELOW_NORMAL) LoggerHook -------------------- before_train: (VERY_HIGH ) RuntimeInfoHook (NORMAL ) IterTimerHook (VERY_LOW ) CheckpointHook -------------------- before_train_epoch: (VERY_HIGH ) RuntimeInfoHook (NORMAL ) IterTimerHook (NORMAL ) DistSamplerSeedHook -------------------- before_train_iter: (VERY_HIGH ) RuntimeInfoHook (NORMAL ) IterTimerHook -------------------- after_train_iter: (VERY_HIGH ) RuntimeInfoHook (NORMAL ) IterTimerHook (BELOW_NORMAL) LoggerHook (LOW ) ParamSchedulerHook (VERY_LOW ) CheckpointHook -------------------- after_train_epoch: (NORMAL ) IterTimerHook (LOW ) ParamSchedulerHook (VERY_LOW ) CheckpointHook -------------------- before_val: (VERY_HIGH ) RuntimeInfoHook -------------------- before_val_epoch: (NORMAL ) IterTimerHook -------------------- before_val_iter: (NORMAL ) IterTimerHook -------------------- after_val_iter: (NORMAL ) IterTimerHook (NORMAL ) SegVisualizationHook (BELOW_NORMAL) LoggerHook -------------------- after_val_epoch: (VERY_HIGH ) RuntimeInfoHook (NORMAL ) IterTimerHook (BELOW_NORMAL) LoggerHook (LOW ) ParamSchedulerHook (VERY_LOW ) CheckpointHook -------------------- after_val: (VERY_HIGH ) RuntimeInfoHook -------------------- after_train: (VERY_HIGH ) RuntimeInfoHook (VERY_LOW ) CheckpointHook -------------------- before_test: (VERY_HIGH ) RuntimeInfoHook -------------------- before_test_epoch: (NORMAL ) IterTimerHook -------------------- before_test_iter: (NORMAL ) IterTimerHook -------------------- after_test_iter: (NORMAL ) IterTimerHook (NORMAL ) SegVisualizationHook (BELOW_NORMAL) LoggerHook -------------------- after_test_epoch: (VERY_HIGH ) RuntimeInfoHook (NORMAL ) IterTimerHook (BELOW_NORMAL) LoggerHook -------------------- after_test: (VERY_HIGH ) RuntimeInfoHook -------------------- after_run: (BELOW_NORMAL) LoggerHook --------------------这些是什么意思
06-12
### 解释警告信息和日志输出的含义 #### DeprecationWarning 和 UserWarning 的解决方法 在代码运行过程中,`DeprecationWarning` 和 `UserWarning` 是常见的警告类型。`DeprecationWarning` 表示某些功能或参数在未来版本中将被废弃,而 `UserWarning` 则是开发者为了提醒用户某些潜在问题而发出的警告。例如,在 `mmsegmentation` 中,`pretrained` 参数已被弃用,建议使用 `init_cfg` 来替代[^1]。以下是具体的替代方法: ```python # 替代 pretrained 参数的方法 model = dict( type='EncoderDecoder', backbone=dict( type='ResNet', depth=50, init_cfg=dict(type='Pretrained', checkpoint='path/to/checkpoint')), decode_head=dict(...)) ``` #### `cross_entropy_loss` 中 `avg_non_ignore` 参数的作用 在 `mmsegmentation` 的损失函数配置中,`avg_non_ignore` 参数决定了如何处理忽略类别的样本。如果设置为 `True`,则在计算平均损失时会排除忽略类别的样本;如果设置为 `False`,则所有样本都会参与平均计算,包括忽略类别的样本[^2]。以下是一个配置示例: ```python # 配置 cross_entropy_loss loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True) ``` #### `SyncBatchNorm` 自动回退到 `BatchNormXd` 的原因 当使用 `SyncBatchNorm` 时,如果分布式训练环境未正确配置或者 GPU 数量不足,框架可能会自动将 `SyncBatchNorm` 回退到 `BatchNormXd`。这是为了避免因同步操作导致的性能下降或错误[^3]。确保分布式训练环境正确配置可以减少这种回退的发生。 #### `VisualizationHook` 中 `draw=False` 不生效的含义 在 `mmsegmentation` 的 `VisualizationHook` 中,`draw=False` 的设置可能不生效,通常是因为钩子的逻辑未正确处理该参数。具体来说,`draw` 参数控制是否在训练过程中绘制中间结果,但如果钩子未正确实现对 `draw` 参数的检查,则可能导致此行为失效[^4]。以下是一个可能的修复方法: ```python class VisualizationHook: def __init__(self, draw=False): self.draw = draw def after_train_iter(self, runner): if not self.draw: return # 绘制逻辑 ``` #### Hooks 的执行顺序解释 在深度学习框架中,`hooks` 的执行顺序通常由其优先级决定。例如,在 `mmsegmentation` 中,`hooks` 的默认执行顺序为:`LoggerHook` > `EvalHook` > `CheckpointHook` > `VisualizationHook` 等[^5]。可以通过调整 `priority` 参数来改变特定 `hook` 的执行顺序。例如: ```python custom_hooks = [ dict(type='VisualizationHook', priority='LOW'), dict(type='EvalHook', priority='HIGH') ] ``` ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值