ResNet18 训练代码解析
训练一个 ResNet-18 来分类 11 类食物图片,包含 数据加载、模型初始化、超参数设置和训练。它还支持 半监督学习,使用无标签数据提升模型性能。
1. 代码结构
代码由以下几个部分组成:
- 导入库
- 随机种子设置
- 超参数设置
- 数据加载
- 模型初始化
- 优化器设置
- 训练参数字典
- 执行训练
2. 代码解析
(1)导入依赖库
import random
import torch
import torch.nn as nn
import numpy as np
import os
from model_utils.model import initialize_model
from model_utils.train import train_val
from model_utils.data import getDataLoader
torch
&torch.nn
:PyTorch 深度学习框架。random
&numpy
:用于设置随机种子,确保实验可复现。os
:用于操作系统路径、环境变量等。- 自定义模块:
initialize_model
:加载指定的神经网络(如 ResNet-18)。train_val
:执行模型训练和验证。getDataLoader
:加载数据。
(2)设置随机种子
def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(