深度学习训练框架——监督学习为例

训练框架


本文内容以pytorch为例

1. 模型网络结构

自定义网络模型继承‘nn.Module’,实现模型的参数的初始化与前向传播;自定义网络模型可以添加权重初始化、网络模块组合等其他方法

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

2. 数据读取与数据加载

数据集类的基础方法

  • dataset:

需要包含数据迭代方法

def __getitem__(self, index):
        image, target = self.list(index)
        return image, target

利用torch.utils.data.DataLoader封装后,用于迭代遍历数据元素;

数据长度方法

def __len__(self):
    return self.dataset_size

数据加载

  • dataloader:
    对数据集类(通常实现了 getitemlen 方法)时,
    你可以使用 DataLoader 来轻松地进行批量加载、打乱数据、并行加载以及多进程数据加载。
    collate_fn:将字典或数组数据流进行拆分,拆分为图像、label、边界框、文字编码等不同类型数据与模型的输入与输出相匹配

2.1Dataloater参数

参数:

  • dataset (Dataset): 加载数据的数据集。
  • batch_size (int, 可选): 每批加载的样本数量(默认:1)。
  • shuffle (bool, 可选): 设置为 True 以在每个 epoch 重新洗牌数据(默认:False)。
  • sampler (Sampler 或 Iterable, 可选): 定义从数据集中抽取样本的策略。可以是任何实现了 len 的 Iterable。如果指定了 sampler,则不能指定 :attr:shuffle。
  • batch_sampler (Sampler 或 Iterable, 可选): 与 :attr:sampler 类似,但一次返回一批索引。与 :attr:batch_size, :attr:shuffle, :attr:sampler, 和 :attr:drop_last 互斥。
    num_workers (int, 可选): 数据加载使用的子进程数量。0 表示数据将在主进程中加载(默认:0)。
  • collate_fn (Callable, 可选): 将样本列表合并以形成 Tensor(s) 的 mini-batch。在使用 map-style 数据集的批量加载时使用。
  • pin_memory (bool, 可选): 如果设置为 True,则数据加载器将在返回它们之前将 Tensors 复制到设备/CUDA 固定内存中。如果你的数据元素是自定义类型,或者你的 :attr:collate_fn 返回的批次是自定义类型,请参见下面的例子。
  • drop_last (bool, 可选): 设置为 True 以丢弃最后一个不完整的批次,如果数据集大小不能被批量大小整除。如果设置为 False 并且数据集大小不能被批量大小整除,则最后一个批次会较小(默认:False)。
  • timeout (numeric, 可选): 如果为正数,这是从工作进程收集一个批次的超时值。应始终为非负数(默认:0)。
  • worker_init_fn (Callable, 可选): 如果不是 None,这将在每个工作进程子进程上调用,输入为工作进程 id(一个在 [0, num_workers - 1] 中的 int),在设置种子和数据加载之前。(默认:None)
  • generator (torch.Generator, 可选): 如果不是 None,这个 RNG 将被 RandomSampler 用来生成随机索引,并被多进程用来为工作进程生成 base_seed。(默认:None)
  • prefetch_factor (int, 可选,关键字参数): 每个工作进程预先加载的批次数量。2 意味着将有总共 2*num_workers 个批次被预先加载。(默认值取决于 num_workers 的设定值。如果 num_workers=0,默认是 None。否则如果 num_workers>0,默认是 2)。
  • persistent_workers (bool, 可选): 如果设置为 True,则数据加载器在数据集被消费一次后不会关闭工作进程。这允许保持工作进程的 Dataset 实例存活。(默认:False)。
  • pin_memory_device (str, 可选): 如果将 pin_memory 设置为 true,则数据加载器将在返回它们之前将 Tensors 复制到设备固定内存中。

2.2 collate_fn

class CollateFunc(object):
    def __call__(self, batch):
        targets = []
        images = []

        for sample in batch:
            image = sample[0]
            target = sample[1]

            images.append(image)
            targets.append(target)

        images = torch.stack(images, 0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云朵不吃雨

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值