关于Pytorch-Lightning的优势和相关知识,下面的博客做了很好的介绍,因为本人课题需要,所以参考该博客尝试搭建一个基于Pytorch-Lightning的深度估计系统。Pytorch Lightning 完全攻略。
前言
使用Pytorch Lightning框架的深度估计系统是有的,比如SC-DpethV3,但是我要改进的系统其所用框架是Pytorch,当我参考SC-DpethV3来转换目标系统时,陷入了半懂半不懂的境地(可能是我本人之前没有接触过深度学习),因此我想从0自己搭一个基于Pytorch-Lightning的深度估计框架出来,也方便我下一步的课题开展。
Pytorch Lightning的基本架构
这里我使用的是Pytorch Lightning 完全攻略提出的通用架构,架构链接。在此十分感谢博主Takanashi的辛苦工作!
在README文件中作者对Pytorch Lightning以及其总结的模板进行了介绍,模板结构如下:

根目录下主要放两个文件:main.py和utils.py(辅助用),main.py包含三个功能:1、定义参数解析器parser,指定一些参数。2、callback函数设置:自动存档,Early Stop及LR Scheduler等,在pl.Trainer中会用到。3、将模型接口,数据集接口,训练器实例化。
data和modle两个文件夹中放入__init__.py文件,做成包,将接口类导入
from .data_interface import DInterface
from .model_interface import MInterface
这样做就可以再main.py中直接从文件夹导入接口类如from model import MInterface
data_interface.py和model_interface.py主要用来定义接口类:DInterface和MInterface,这两个接口可以控制不同的模型,不需要再修改,只需要修改传入的参数即可。
作者还提醒使用严格的命名-snakecase
一、main.py
本页结构清晰,有三部分:1、加载回调的函数load_callbacks() ;2、主函数main(args) ;3、程序入口if __name__ == '__main__':
**程序入口:**设置网络参数,调用main函数。
将参数修改为自己要设置的,参考的是目标系统,有些参数可能就不用,但是暂时先保留着。
将每个参数的含义注释清楚,将未用到的删除。
论文中的数学公式使用了SSIM(结构相似度),但是代码中却未使用SSIM。
关于dataloader workers:
1、每次dataloader加载数据时:dataloader一次性创建num_worker个worker,(也可以说dataloader一次性创建num_worker个工作进程,worker也是普通的工作进程),并用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。
2、然后,dataloader从RAM中找本轮迭代要用的batch,如果找到了,就使用。如果没找到,就要num_worker个worker继续加载batch到内存,直到dataloader在RAM中找到目标batch。一般情况下都是能找到的,因为batch_sampler指定batch时当然优先指定本轮要用的batch。num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮…迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)。num_workers的经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些。
3、如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是在RAM中找batch,找不到时再加载相应的batch。缺点当然是速度更慢。最好设置为CPU的数量。
main函数:
设置随机种子,我看目标代码预训练的seed设置的是0.
pl.seed_everything(args.seed)
关键点路径的获取,好像是用来加载预训练的模型的。
load_model_path_by_args(args)
需要预训练模型的路径,模型的版本名称,模型的版本号,在scv3中先设置了logger日志器,再设置检查点的保存和加载。
因此参考scv3先设置日志记录器:日志的保存目录–ckpts,日志的名字
并在回调函数列表中添加ModelCheckpoint
从之前的检查点恢复:
litemono 是使用预训练的编码器(在imagenet上预训练过的需要下载),代码中load_model(self)加载的模型名称是测试和评估时在命令行特别注明的,就是训练好的模型权重;self.load_pretrain()函数使用的是在ImageNet上预先训练的主干(深度编码器)的权重–mypretrain。
scv3中的检查点恢复指的应该是整个系统的而不是某个编码器的。并且scv3的训练命令行中也没有指明检查点。对于我来说,这个检查点是不存在的,因为litemono的ckpts是分散的不是pl完整系统的。
将模板代码在的加载检查点路径函数删去,直接在参数列表设置,并使用系统直接加载,**后面的参数是否是必要的?**可以不加
model = MInterface(args)
model = model.load_from_checkpoint(args.ckpt_path, strict=False, args = args)
下面是训练器,如何将设置的检查点回调放到Trainer训练器中,从scv3代码看,直接将检查点函数放到Trainer参数callbacks(列表)中也是可以的所以我直接callbacks=callback_list ,对于其他参数limit_train_batches(每个epoch运行多少个batch,litemono有限制么?好像没有),limit_val_batches,num_sanity_val_steps(在开始训练之前,健康检查运行n批val。这可以捕获验证中的任何错误,而无需等待第一次验证检查。培训师默认使用2个步骤。在此处关闭或修改它。litemono未做设置)
关于Trainer更细节的需要前往官网了解。
下面就是正常的拟合
数据集接口data_interface.py
litemono是怎么加载自己的数据的?
1、通过字典的方式定义了两种数据类型kitti_raw和kitti_odom的处理类KITTIRAWDataset和KITTIOdomDataset(获得图像路径)
datasets_dict = {
"kitti": datasets.KITTIRAWDataset,
"kitti_odom": datasets.KITTIOdomDataset}
self.dataset = datasets_dict[self.opt.dataset]
KITTIRAWDataset的继承关系为:KITTIRAWDataset<-KITTIDataset<-MonoDataset<-data.Dataset
主要干了两个事,获得图像的路径,产生深度真值(利用激光雷达和标定参数)。如果是kiitidepth数据集可以直接获得真值深度图
根据在参数列表中的设置,确定是哪个处理类-这里确定是KITTIRAWDataset。
2、找到位于本地文件目录/splits/eigen_zhou/xxx_files.txt 的文件路径
fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")
3、确定训练集和验证集的文件名还有图像的格式:如果一个数据集没有提前分割好,在pl中应该如何设置?
train_filenames = readlines(fpath.format("train"))
val_filenames = readlines(fpath.format("val"))
img_ext = '.png' if self.opt.png else '.jpg'
4、确定train_files.txt中数据的训练样本数量以及总步数=训练样本数量/每批次样本数 x epochs的数量 也就是总共需要多少个batch
num_train_samples = len(train_filenames)
self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs
5、实例化KITTIRAWDataset和DataLoader:数据集的创建以及向训练传递数据
train_dataset = self.dataset(
self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
self.train_loader = DataLoader(
train_dataset, self.opt.batch_size, True,
num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
6、验证数据集的创建并向验证传递数据。内置函数iter()并将self.val_loader作为参数传递,从而创建一个迭代器,之后在代码中的val(self)函数会使用next()函数会调用迭代器的__next__()方法,从而返回迭代器中的下一个值。直到迭代器中的所有值都返回后,抛出一个StopIteration异常。
val_dataset = self.dataset(
self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
self.val_loader = DataLoader(
val_dataset, self.opt.batch_size, True,
num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
self.val_iter = iter(self.val_loader)
如何修改?scv3可适当参考,另外发现一篇博客讲解如何使用LightningDataModule的,LightningDataModule的使用
scv3首先进行初始化并保存了超参数,通过get_training_size获得training_size,指定图像的尺寸,并获取伪真值(为了设置伪真值自定义了一个训练集函数)。这个不好改啊。scv3提供伪真值的训练集了。
我想用scv3数据集来训练改完后的litemono(因为我要验证其对动态场景的稳定性),也就是说数据集接口部分主要参考scv3。
需要注意的是litemono对数据的要求有无冲突,litemono后续用train_loader 进行模型训练求解输出和损失,这些操作依赖数据的color_aug属性?这是一个在MonoDataset类中进行了色彩增强后的图像。能否在接口改为scv3后也添加一个这个属性?这个属性很重要么,训练部分就是将所有增强后的图像放入到网络模型中。
先看一下scv3数据集的结构分布吧

结构很清晰,主要包含两个大文件夹,Training和Testing。Training包含三类文件夹:1、场景;2、训练txt文本(跟litemono一样,包含训练所用的数据名);3、验证txt文本。
每个场景一堆有序图像,相机内参txt文件,真值深度文件夹,伪深度图文件夹
Testing包含三种文件夹:1、color用于测试的图像文件夹;2、真值深度文件夹;3、用于动态区域评估的语义分割掩膜。
修改数据接口:
1、初始化
将**vars(args)都替换为args。
保存超参数
self.save_hyperparameters()
litemono对图像尺寸的要求是height=192,width=640,必须是32的整数倍?跟全连接层有关,含有全连接层的网络必须将图像resize为特定尺寸,比如yolov5就是要求为32的整数倍。
参考scv3将获取图像尺寸函数写到utils.py中
将加载伪真值取消默认使用。
数据预处理:SCV3使用一系列自定义的图像处理防止训练过拟合,litemono也有数据增强部分。
self.train_transform = custom_transforms.Compose([
custom_transforms.RandomHorizontalFlip(),
custom_transforms.RandomScaleCrop(),
custom_transforms.RescaleTo(self.training_size),
custom_transforms.ArrayToTensor(),
custom_transforms.Normalize()]
)
SCV3使用了def __call__函数来执行custom_transforms.py中定义的类,按照列表中的顺序执行。
RandomHorizontalFlip:以0.5的概率随机水平翻转给定的numpy数组
RandomScaleCrop:随机将图像放大至15%,并将其裁剪为与以前相同的大小。
RescaleTo:重新缩放图像以进行训练或验证
ArrayToTensor:将numpy.ndarray(H x W x C)的列表以及内部矩阵转换为torch的列表。形状为(CxHxW)的具有本征张量的FloatTensor。
Normalize:将tensor中的每个元素减去self.mean,然后除以self.std。这是一个简单的批量标准化(Batch Normalization)操作,可以提高模型的性能和泛化能力。
使用scv3的数据预处理方法对数据预处理,这和transforms.Compose是一样的(不如自定义的灵活)
2、实例化数据集def setup(self, stage=None):
训练数据集的实例化:
self.train_dataset = TrainFolder(
self.hparams.hparams.dataset_dir,
train=True,
transform=self.train_transform,
sequence_length=self.hparams.hparams.sequence_length,
skip_frames=self.hparams.hparams.skip_frames,
use_frame_index=self.hparams.hparams.use_frame_index,
with_pseudo_depth=self.load_pseudo_depth
)
使用了类TrainFolder,继承于data.Dataset
需要参数包含:1、数据集路径;2、训练标识(真);3、图像变换器;3、用于训练的图像序列长度=3;4、跳过帧数为1;5、数据集的名称;6、使用帧索引;7、使用伪真值深度。
固定随机种子为0;
找到train.txt文件,并打开,循环文件中对应的数据名称,形成一个新的路径读取到一个列表self.scenes中。
将其他参数实例化
调用crawl_folders函数:1、循环self.scenes中的数据;2、将scene下cam.txt的内参转为矩阵;3、对图像进行排序imgs = sorted(scene.files('*.jpg')) ;4、如果使用了use_frame_index,则按照frame_index对图像再次排序imgs = [imgs[d] for d in frame_index] ;5、如果使用伪真值则对伪真值图像按frame_index进行排序pseudo_depths = [pseudo

博主因课题需要,尝试从0搭建基于Pytorch-Lightning的深度估计框架。介绍了Pytorch Lightning基本架构,包括main.py、数据集接口和模型接口等文件的设置与修改,还涉及数据加载、预处理、模型训练、改写模型接口、添加动态损失、测试代码编写及确定训练命令行等内容。
最低0.47元/天 解锁文章
6997






