pytorch方法合集(一)
model.apply(fn)
model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。
https://blog.youkuaiyun.com/qq_37025073/article/details/106739513
模型处理
https://blog.youkuaiyun.com/joseph__lagrange/article/details/109716583
torch.backends.cudnn.benchmark=True
在一般场景下,只要简单地在 PyTorch 程序开头将其值设置为 True,就可以大大提升卷积神经网络的运行速度。
https://zhuanlan.zhihu.com/p/73711222
storch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
每过step_size个epoch,做一次更新:
参数:
optimizer (Optimizer):要更改学习率的优化器;
step_size(int):每训练step_size个epoch,更新一次参数;
gamma(float):更新lr的乘法因子;
last_epoch (int):最后一个epoch的index,如果是训练了很多个epoch后中断了,继续训练,这个值就等于加载的模型的epoch。默认为-1表示从头开始训练,即从epoch=1开始。
https://blog.youkuaiyun.com/qyhaill/article/details/103043637
DataLoader
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
参数说明:
dataset (Dataset) – 数据集
batch_size (int, optional) – 一次性加载的样本数 (default: 1).
shuffle (bool, optional) – 是否打乱数据 (default: False).
sampler (Sampler or Iterable, optional) – 定义从数据集提取样本的策略。可以是任何实现了__len__的Iterable。如果指定了,则不能指定shuffle。
batch_sampler (Sampler or Iterable, optional) – 类似于 sampler, 但每次返回一批索引。 和 batch_size, shuffle, sampler, drop_last互相排斥。
num_workers (int, optional) – 要使用多少子进程来加载数据。 0 表示数据会被加载进主程序. (default: 0)
collate_fn (callable, optional) – 合并一个样本列表,形成一个小批张量。当从映射样式的数据集使用批处理加载时使用。
pin_memory (bool, optional) – 如果为True,数据加载器将在返回张量之前将张量复制到CUDA固定内存中。 如果您的数据元素是自定义类型,或者您的collate_fn返回一个自定义类型的批处理,请参见下面的示例。
https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
https://blog.youkuaiyun.com/yangwangnndd/article/details/95385628
drop_last (bool, optional) – 如果数据集大小不能被批大小整除,则设置为True以删除最后一个不完整的批。. (default: False)
timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional) – 如果不是None,则在seeding之后和数据加载之前,每个worker id(在[0,num_workers - 1]中是一个int)作为输入的worker子进程都会调用这个函数。 (default: None)
generator (torch.Generator, optional) – 如果不是None,这个RNG将被RandomSampler用来生成随机索引和multiprocessing来为workers生成base_seed。 (default: None)
prefetch_factor (int, optional, keyword-only arg) – 每个worker提前装样的数量。2表示将在所有workers中预取2 * num_workers样本。(default: 2)
persistent_workers (bool, optional) – 如果为True,数据加载器将不会在数据集被消耗一次后关闭工作进程。 这允许实时维护workers Dataset实例。 (default: False)