训练框架
文章目录
本文内容以pytorch为例
1. 模型网络结构
自定义网络模型继承‘nn.Module’,实现模型的参数的初始化与前向传播;自定义网络模型可以添加权重初始化、网络模块组合等其他方法
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
2. 数据读取与数据加载
数据集类的基础方法
- dataset:
需要包含数据迭代方法
def __getitem__(self, index):
image, target = self.list(index)
return image, target
利用torch.utils.data.DataLoader封装后,用于迭代遍历数据元素;
数据长度方法
def __len__(self):
return self.dataset_size
数据加载
- dataloader:
对数据集类(通常实现了 getitem 和 len 方法)时,
你可以使用 DataLoader 来轻松地进行批量加载、打乱数据、并行加载以及多进程数据加载。
collate_fn:将字典或数组数据流进行拆分,拆分为图像、label、边界框、文字编码等不同类型数据与模型的输入与输出相匹配
2.1Dataloater参数
参数:
- dataset (Dataset): 加载数据的数据集。
- batch_size (int, 可选): 每批加载的样本数量(默认:1)。
- shuffle (bool, 可选): 设置为 True 以在每个 epoch 重新洗牌数据(默认:False)。
- sampler (Sampler 或 Iterable, 可选): 定义从数据集中抽取样本的策略。可以是任何实现了 len 的 Iterable。如果指定了 sampler,则不能指定 :attr:shuffle。
- batch_sampler (Sampler 或 Iterable, 可选): 与 :attr:sampler 类似,但一次返回一批索引。与 :attr:batch_size, :attr:shuffle, :attr:sampler, 和 :attr:drop_last 互斥。
num_workers (int, 可选): 数据加载使用的子进程数量。0 表示数据将在主进程中加载(默认:0)。 - collate_fn (Callable, 可选): 将样本列表合并以形成 Tensor(s) 的 mini-batch。在使用 map-style 数据集的批量加载时使用。
- pin_memory (bool, 可选): 如果设置为 True,则数据加载器将在返回它们之前将 Tensors 复制到设备/CUDA 固定内存中。如果你的数据元素是自定义类型,或者你的 :attr:collate_fn 返回的批次是自定义类型,请参见下面的例子。
- drop_last (bool, 可选): 设置为 True 以丢弃最后一个不完整的批次,如果数据集大小不能被批量大小整除。如果设置为 False 并且数据集大小不能被批量大小整除,则最后一个批次会较小(默认:False)。
- timeout (numeric, 可选): 如果为正数,这是从工作进程收集一个批次的超时值。应始终为非负数(默认:0)。
- worker_init_fn (Callable, 可选): 如果不是 None,这将在每个工作进程子进程上调用,输入为工作进程 id(一个在 [0, num_workers - 1] 中的 int),在设置种子和数据加载之前。(默认:None)
- generator (torch.Generator, 可选): 如果不是 None,这个 RNG 将被 RandomSampler 用来生成随机索引,并被多进程用来为工作进程生成 base_seed。(默认:None)
- prefetch_factor (int, 可选,关键字参数): 每个工作进程预先加载的批次数量。2 意味着将有总共 2*num_workers 个批次被预先加载。(默认值取决于 num_workers 的设定值。如果 num_workers=0,默认是 None。否则如果 num_workers>0,默认是 2)。
- persistent_workers (bool, 可选): 如果设置为 True,则数据加载器在数据集被消费一次后不会关闭工作进程。这允许保持工作进程的 Dataset 实例存活。(默认:False)。
- pin_memory_device (str, 可选): 如果将 pin_memory 设置为 true,则数据加载器将在返回它们之前将 Tensors 复制到设备固定内存中。
2.2 collate_fn
class CollateFunc(object):
def __call__(self, batch):
targets = []
images = []
for sample in batch:
image = sample[0]
target = sample[1]
images.append(image)
targets.append(target)
images = torch.stack(images, 0